diff --git a/.github/workflows/release-notify-wechat.yml b/.github/workflows/release-notify-wechat.yml new file mode 100644 index 00000000..935d84d5 --- /dev/null +++ b/.github/workflows/release-notify-wechat.yml @@ -0,0 +1,164 @@ +name: Release Notify Workflow + +on: + pull_request: + types: [closed] + +jobs: + notify: + if: > + github.event.pull_request.merged == true && + startsWith(github.event.pull_request.base.ref, 'release') + runs-on: ubuntu-latest + + steps: + # 防止 GitHub HEAD 未同步 + - run: sleep 3 + + # 1️⃣ 获取分支 HEAD + - name: Get HEAD + id: head + run: | + HEAD_SHA=$(curl -s \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \ + | jq -r '.object.sha') + echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT + + # 2️⃣ 判断是否最终PR + - name: Check Latest + id: check + run: | + if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then + echo "ok=true" >> $GITHUB_OUTPUT + else + echo "ok=false" >> $GITHUB_OUTPUT + fi + + # 3️⃣ 尝试从 PR body 提取 Sourcery 摘要 + - name: Extract Sourcery Summary + if: steps.check.outputs.ok == 'true' + id: sourcery + env: + PR_BODY: ${{ github.event.pull_request.body }} + run: | + python3 << 'PYEOF' + import os, re + + body = os.environ.get("PR_BODY", "") or "" + match = re.search( + r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)", + body, + re.DOTALL + ) + + if match: + summary = match.group(1).strip() + found = "true" + else: + summary = "" + found = "false" + + with open("sourcery_summary.txt", "w", encoding="utf-8") as f: + f.write(summary) + + with open(os.environ["GITHUB_OUTPUT"], "a") as gh: + gh.write(f"found={found}\n") + gh.write("summary< commits.txt + + - name: AI Summary (Qwen Fallback) + if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false' + id: qwen + env: + DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }} + run: | + python3 << 'PYEOF' + import json, os, urllib.request + + with open("commits.txt", "r") as f: + commits = f.read().strip() + + prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits + payload = {"model": "qwen-plus", "input": {"prompt": prompt}} + data = json.dumps(payload, ensure_ascii=False).encode("utf-8") + + req = urllib.request.Request( + "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", + data=data, + headers={ + "Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"], + "Content-Type": "application/json" + } + ) + resp = urllib.request.urlopen(req) + result = json.loads(resp.read().decode()) + summary = result.get("output", {}).get("text", "AI 摘要生成失败") + + with open(os.environ["GITHUB_OUTPUT"], "a") as gh: + gh.write("summary< � **分支**: " + os.environ["BRANCH"] + "\n" + "> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n" + "> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n" + "> 🔢 **PR编号**: #" + pr_number + "\n" + "> 🔖 **Commit**: " + short_sha + "\n\n" + "### 🧠 " + label + "\n" + + summary + "\n\n" + "---\n" + "🔗 [查看PR详情](" + os.environ["PR_URL"] + ")" + ) + payload = {"msgtype": "markdown", "markdown": {"content": content}} + data = json.dumps(payload, ensure_ascii=False).encode("utf-8") + req = urllib.request.Request( + os.environ["WECHAT_WEBHOOK"], + data=data, + headers={"Content-Type": "application/json"} + ) + resp = urllib.request.urlopen(req) + print(resp.read().decode()) + PYEOF diff --git a/.gitignore b/.gitignore index 0ec6822c..a1896da7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ time.log celerybeat-schedule.db search_results.json redbear-mem-metrics/ +redbear-mem-benchmark/ pitch-deck/ api/migrations/versions diff --git a/api/app/config/default_free_plan.py b/api/app/config/default_free_plan.py new file mode 100644 index 00000000..409b4f7b --- /dev/null +++ b/api/app/config/default_free_plan.py @@ -0,0 +1,77 @@ +""" +社区版默认免费套餐配置 +当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底 + +可通过环境变量覆盖配额配置,格式:QUOTA_ +例如:QUOTA_END_USER_QUOTA=100 +""" + +import os + + +def _get_quota_from_env(): + """从环境变量获取配额配置""" + quota_keys = [ + "workspace_quota", + "skill_quota", + "app_quota", + "knowledge_capacity_quota", + "memory_engine_quota", + "end_user_quota", + "ontology_project_quota", + "model_quota", + "api_ops_rate_limit", + ] + quotas = {} + for key in quota_keys: + env_key = f"QUOTA_{key.upper()}" + env_value = os.getenv(env_key) + if env_value is not None: + try: + quotas[key] = float(env_value) if '.' in env_value else int(env_value) + except ValueError: + pass + return quotas + + +def _build_default_free_plan(): + """构建默认免费套餐配置""" + base = { + "name": "记忆体验版", + "name_en": "Memory Experience", + "category": "saas_personal", + "tier_level": 0, + "version": "1.0", + "status": True, + "price": 0, + "billing_cycle": "permanent_free", + "core_value": "感受永久记忆", + "core_value_en": "Experience Permanent Memory", + "tech_support": "社群交流", + "tech_support_en": "Community Support", + "sla_compliance": "无", + "sla_compliance_en": "None", + "page_customization": "无", + "page_customization_en": "None", + "theme_color": "#64748B", + "quotas": { + "workspace_quota": 1, + "skill_quota": 5, + "app_quota": 2, + "knowledge_capacity_quota": 0.3, + "memory_engine_quota": 1, + "end_user_quota": 1, + "ontology_project_quota": 3, + "model_quota": 1, + "api_ops_rate_limit": 50, + }, + } + + env_quotas = _get_quota_from_env() + if env_quotas: + base["quotas"].update(env_quotas) + + return base + + +DEFAULT_FREE_PLAN = _build_default_free_plan() diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 50e9e0b0..e9417d68 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -47,7 +47,8 @@ from . import ( user_memory_controllers, workspace_controller, ontology_controller, - skill_controller + skill_controller, + tenant_subscription_controller, ) # 创建管理端 API 路由器 @@ -98,5 +99,7 @@ manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) manager_router.include_router(skill_controller.router) manager_router.include_router(i18n_controller.router) +manager_router.include_router(tenant_subscription_controller.router) +manager_router.include_router(tenant_subscription_controller.public_router) __all__ = ["manager_router"] diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index db3c7536..3d97f2a2 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -28,6 +28,7 @@ from app.services.app_statistics_service import AppStatisticsService from app.services.workflow_import_service import WorkflowImportService from app.services.workflow_service import WorkflowService, get_workflow_service from app.services.app_dsl_service import AppDslService +from app.core.quota_stub import check_app_quota router = APIRouter(prefix="/apps", tags=["Apps"]) logger = get_business_logger() @@ -35,6 +36,7 @@ logger = get_business_logger() @router.post("", summary="创建应用(可选创建 Agent 配置)") @cur_workspace_access_guard() +@check_app_quota def create_app( payload: app_schema.AppCreate, db: Session = Depends(get_db), @@ -269,6 +271,19 @@ def update_agent_config( return success(data=app_schema.AgentConfig.model_validate(cfg)) +@router.get("/{app_id}/model/parameters/default", summary="获取 Agent 模型参数默认配置") +@cur_workspace_access_guard() +def get_agent_model_parameters( + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + workspace_id = current_user.current_workspace_id + service = AppService(db) + model_parameters = service.get_default_model_parameters(app_id=app_id) + return success(data=model_parameters, msg="获取 Agent 模型参数默认配置") + + @router.get("/{app_id}/config", summary="获取 Agent 配置") @cur_workspace_access_guard() def get_agent_config( @@ -1250,9 +1265,11 @@ async def export_app( async def import_app( file: UploadFile = File(...), db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + app_id: Optional[str] = Form(None), ): """从 YAML 文件导入 agent / multi_agent / workflow 应用。 + 传入 app_id 时覆盖该应用的配置(类型必须一致),否则创建新应用。 跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。 """ if not file.filename.lower().endswith((".yaml", ".yml")): @@ -1263,13 +1280,15 @@ async def import_app( if not dsl or "app" not in dsl: return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST) - new_app, warnings = AppDslService(db).import_dsl( + target_app_id = uuid.UUID(app_id) if app_id else None + result_app, warnings = AppDslService(db).import_dsl( dsl=dsl, workspace_id=current_user.current_workspace_id, tenant_id=current_user.tenant_id, user_id=current_user.id, + app_id=target_app_id, ) return success( - data={"app": app_schema.App.model_validate(new_app), "warnings": warnings}, + data={"app": app_schema.App.model_validate(result_app), "warnings": warnings}, msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "") ) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index b5c0a5ae..cc1f8c98 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -443,10 +443,10 @@ async def retrieve_chunks( match retrieve_data.retrieve_type: case chunk_schema.RetrieveType.PARTICIPLE: rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) - return success(data=rs, msg="retrieval successful") + return success(data=jsonable_encoder(rs), msg="retrieval successful") case chunk_schema.RetrieveType.SEMANTIC: rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) - return success(data=rs, msg="retrieval successful") + return success(data=jsonable_encoder(rs), msg="retrieval successful") case _: rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) diff --git a/api/app/controllers/file_controller.py b/api/app/controllers/file_controller.py index f7bd0e7a..6f8b1b97 100644 --- a/api/app/controllers/file_controller.py +++ b/api/app/controllers/file_controller.py @@ -19,6 +19,7 @@ from app.models.user_model import User from app.schemas import file_schema, document_schema from app.schemas.response_schema import ApiResponse from app.services import file_service, document_service +from app.core.quota_stub import check_knowledge_capacity_quota # Obtain a dedicated API logger @@ -131,6 +132,7 @@ async def create_folder( @router.post("/file", response_model=ApiResponse) +@check_knowledge_capacity_quota async def upload_file( kb_id: uuid.UUID, parent_id: uuid.UUID, diff --git a/api/app/controllers/knowledge_controller.py b/api/app/controllers/knowledge_controller.py index afda7cce..5cd87647 100644 --- a/api/app/controllers/knowledge_controller.py +++ b/api/app/controllers/knowledge_controller.py @@ -27,6 +27,7 @@ from app.schemas import knowledge_schema from app.schemas.response_schema import ApiResponse from app.services import knowledge_service, document_service from app.services.model_service import ModelConfigService +from app.core.quota_stub import check_knowledge_capacity_quota # Obtain a dedicated API logger api_logger = get_api_logger() @@ -179,6 +180,7 @@ async def get_knowledges( @router.post("/knowledge", response_model=ApiResponse) +@check_knowledge_capacity_quota async def create_knowledge( create_data: knowledge_schema.KnowledgeCreate, db: Session = Depends(get_db), diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 76eed50f..545f8302 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -34,6 +34,7 @@ from app.services.memory_storage_service import ( search_entity, search_statement, ) +from app.core.quota_stub import check_memory_engine_quota from fastapi import APIRouter, Depends, Header from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -76,6 +77,7 @@ async def get_storage_info( @router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 +@check_memory_engine_quota def create_config( payload: ConfigParamsCreate, current_user: User = Depends(get_current_user), diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 71fd41ad..6105c3d8 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -15,6 +15,7 @@ from app.core.response_utils import success from app.schemas.response_schema import ApiResponse, PageData from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService from app.core.logging_config import get_api_logger +from app.core.quota_stub import check_model_quota, check_model_activation_quota # 获取API专用日志器 api_logger = get_api_logger() @@ -236,6 +237,7 @@ def delete_model_base( @router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse) +@check_model_quota def add_model_from_plaza( model_base_id: uuid.UUID, db: Session = Depends(get_db), @@ -273,6 +275,7 @@ def get_model_by_id( @router.post("", response_model=ApiResponse) +@check_model_quota async def create_model( model_data: model_schema.ModelConfigCreate, db: Session = Depends(get_db), @@ -303,6 +306,7 @@ async def create_model( @router.post("/composite", response_model=ApiResponse) +@check_model_quota async def create_composite_model( model_data: model_schema.CompositeModelCreate, db: Session = Depends(get_db), @@ -329,6 +333,7 @@ async def create_composite_model( @router.put("/composite/{model_id}", response_model=ApiResponse) +@check_model_activation_quota async def update_composite_model( model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, @@ -370,6 +375,7 @@ def delete_composite_model( @router.put("/{model_id}", response_model=ApiResponse) +@check_model_activation_quota def update_model( model_id: uuid.UUID, model_data: model_schema.ModelConfigUpdate, diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index fe6b3598..602ee709 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -28,6 +28,8 @@ from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, H from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session +from app.core.quota_stub import check_ontology_project_quota + from app.core.config import settings from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header @@ -163,7 +165,7 @@ def _get_ontology_service( api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - support_thinking="thinking" in (api_key_config.capability or []), + capability=api_key_config.capability, max_retries=3, timeout=60.0 ) @@ -287,6 +289,7 @@ async def extract_ontology( # ==================== 本体场景管理接口 ==================== @router.post("/scene", response_model=ApiResponse) +@check_ontology_project_quota async def create_scene( request: SceneCreateRequest, db: Session = Depends(get_db), diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index 80f14cd3..b9fc697c 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -124,10 +124,11 @@ async def get_prompt_opt( skill=data.skill ): # chunk 是 prompt 的增量内容 - yield f"event:message\ndata: {json.dumps(chunk)}\n\n" + yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n" except Exception as e: yield f"event:error\ndata: {json.dumps( - {"error": str(e)} + {"error": str(e)}, + ensure_ascii=False )}\n\n" yield "event:end\ndata: {}\n\n" diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index ddd31071..049535b5 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger +from app.core.quota_manager import check_end_user_quota from app.core.response_utils import success, fail from app.db import get_db, get_db_read from app.dependencies import get_share_user_id, ShareTokenData @@ -308,6 +309,7 @@ def get_conversation( "/chat", summary="发送消息(支持流式和非流式)" ) +@check_end_user_quota async def chat( payload: conversation_schema.ChatRequest, share_data: ShareTokenData = Depends(get_share_user_id), diff --git a/api/app/controllers/service/__init__.py b/api/app/controllers/service/__init__.py index 96da0949..52d4b732 100644 --- a/api/app/controllers/service/__init__.py +++ b/api/app/controllers/service/__init__.py @@ -4,7 +4,17 @@ 认证方式: API Key """ from fastapi import APIRouter -from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller + +from . import ( + app_api_controller, + end_user_api_controller, + memory_api_controller, + memory_config_api_controller, + rag_api_chunk_controller, + rag_api_document_controller, + rag_api_file_controller, + rag_api_knowledge_controller, +) # 创建 V1 API 路由器 service_router = APIRouter() @@ -17,5 +27,6 @@ service_router.include_router(rag_api_file_controller.router) service_router.include_router(rag_api_chunk_controller.router) service_router.include_router(memory_api_controller.router) service_router.include_router(end_user_api_controller.router) +service_router.include_router(memory_config_api_controller.router) __all__ = ["service_router"] diff --git a/api/app/controllers/service/end_user_api_controller.py b/api/app/controllers/service/end_user_api_controller.py index df9996c2..1faea6ef 100644 --- a/api/app/controllers/service/end_user_api_controller.py +++ b/api/app/controllers/service/end_user_api_controller.py @@ -5,23 +5,44 @@ import uuid from fastapi import APIRouter, Body, Depends, Request from sqlalchemy.orm import Session +from app.controllers import user_memory_controllers from app.core.api_key_auth import require_api_key from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger +from app.core.quota_stub import check_end_user_quota from app.core.response_utils import success from app.db import get_db from app.repositories.end_user_repository import EndUserRepository from app.schemas.api_key_schema import ApiKeyAuth +from app.schemas.end_user_info_schema import EndUserInfoUpdate from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse +from app.services import api_key_service from app.services.memory_config_service import MemoryConfigService router = APIRouter(prefix="/end_user", tags=["V1 - End User API"]) logger = get_business_logger() +def _get_current_user(api_key_auth: ApiKeyAuth, db: Session): + """Build a current_user object from API key auth + + Args: + api_key_auth: Validated API key auth info + db: Database session + + Returns: + User object with current_workspace_id set + """ + api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id) + current_user = api_key.creator + current_user.current_workspace_id = api_key_auth.workspace_id + return current_user + + @router.post("/create") @require_api_key(scopes=["memory"]) +@check_end_user_quota async def create_end_user( request: Request, api_key_auth: ApiKeyAuth = None, @@ -37,6 +58,7 @@ async def create_end_user( Optionally accepts a memory_config_id to connect the end user to a specific memory configuration. If not provided, falls back to the workspace default config. + Optionally accepts an app_id to bind the end user to a specific app. """ body = await request.json() payload = CreateEndUserRequest(**body) @@ -71,14 +93,26 @@ async def create_end_user( else: logger.warning(f"No default memory config found for workspace: {workspace_id}") + # Resolve app_id: explicit from payload, otherwise None + app_id = None + if payload.app_id: + try: + app_id = uuid.UUID(payload.app_id) + except ValueError: + raise BusinessException( + f"Invalid app_id format: {payload.app_id}", + BizCode.INVALID_PARAMETER + ) + end_user_repo = EndUserRepository(db) end_user = end_user_repo.get_or_create_end_user_with_config( - app_id=api_key_auth.resource_id, + app_id=app_id, workspace_id=workspace_id, other_id=payload.other_id, memory_config_id=memory_config_id, + other_name=payload.other_name, ) - + end_user.other_name = payload.other_name logger.info(f"End user ready: {end_user.id}") result = { @@ -90,3 +124,50 @@ async def create_end_user( } return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") + + +@router.get("/info") +@require_api_key(scopes=["memory"]) +async def get_end_user_info( + request: Request, + end_user_id: str, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get end user info. + + Retrieves the info record (aliases, meta_data, etc.) for the specified end user. + Delegates to the manager-side controller for shared logic. + """ + current_user = _get_current_user(api_key_auth, db) + return await user_memory_controllers.get_end_user_info( + end_user_id=end_user_id, + current_user=current_user, + db=db, + ) + + +@router.post("/info/update") +@require_api_key(scopes=["memory"]) +async def update_end_user_info( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update end user info. + + Updates the info record (other_name, aliases, meta_data) for the specified end user. + Delegates to the manager-side controller for shared logic. + """ + body = await request.json() + payload = EndUserInfoUpdate(**body) + + current_user = _get_current_user(api_key_auth, db) + return await user_memory_controllers.update_end_user_info( + info_update=payload, + current_user=current_user, + db=db, + ) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index dc5e0408..313781d2 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -1,53 +1,83 @@ """Memory 服务接口 - 基于 API Key 认证""" +from fastapi import APIRouter, Body, Depends, Query, Request +from sqlalchemy.orm import Session + from app.core.api_key_auth import require_api_key from app.core.logging_config import get_business_logger +from app.core.quota_stub import check_end_user_quota from app.core.response_utils import success from app.db import get_db from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.memory_api_schema import ( - CreateEndUserRequest, - CreateEndUserResponse, - ListConfigsResponse, MemoryReadRequest, MemoryReadResponse, + MemoryReadSyncResponse, MemoryWriteRequest, MemoryWriteResponse, + MemoryWriteSyncResponse, ) from app.services.memory_api_service import MemoryAPIService -from fastapi import APIRouter, Body, Depends, Request -from sqlalchemy.orm import Session router = APIRouter(prefix="/memory", tags=["V1 - Memory API"]) logger = get_business_logger() +def _sanitize_task_result(result: dict) -> dict: + """Make Celery task result JSON-serializable. + + Converts UUID and other non-serializable values to strings. + + Args: + result: Raw task result dict from task_service + + Returns: + JSON-safe dict + """ + import uuid as _uuid + from datetime import datetime + + def _convert(obj): + if isinstance(obj, dict): + return {k: _convert(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_convert(i) for i in obj] + if isinstance(obj, _uuid.UUID): + return str(obj) + if isinstance(obj, datetime): + return obj.isoformat() + return obj + + return _convert(result) + + @router.get("") async def get_memory_info(): """获取记忆服务信息(占位)""" return success(data={}, msg="Memory API - Coming Soon") -@router.post("/write_api_service") +@router.post("/write") @require_api_key(scopes=["memory"]) -async def write_memory_api_service( +async def write_memory( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), message: str = Body(..., description="Message content"), ): """ - Write memory to storage. - - Stores memory content for the specified end user using the Memory API Service. + Submit a memory write task. + + Validates the end user, then dispatches the write to a Celery background task + with per-user fair locking. Returns a task_id for status polling. """ body = await request.json() payload = MemoryWriteRequest(**body) logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") - + memory_api_service = MemoryAPIService(db) - - result = await memory_api_service.write_memory( + + result = memory_api_service.write_memory( workspace_id=api_key_auth.workspace_id, end_user_id=payload.end_user_id, message=payload.message, @@ -55,31 +85,53 @@ async def write_memory_api_service( storage_type=payload.storage_type, user_rag_memory_id=payload.user_rag_memory_id, ) - - logger.info(f"Memory write successful for end_user: {payload.end_user_id}") - return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully") + + logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}") + return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted") -@router.post("/read_api_service") +@router.get("/write/status") @require_api_key(scopes=["memory"]) -async def read_memory_api_service( +async def get_write_task_status( + request: Request, + task_id: str = Query(..., description="Celery task ID"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Check the status of a memory write task. + + Returns the current status and result (if completed) of a previously submitted write task. + """ + logger.info(f"Write task status check - task_id: {task_id}") + + from app.services.task_service import get_task_memory_write_result + result = get_task_memory_write_result(task_id) + + return success(data=_sanitize_task_result(result), msg="Task status retrieved") + + +@router.post("/read") +@require_api_key(scopes=["memory"]) +async def read_memory( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), message: str = Body(..., description="Query message"), ): """ - Read memory from storage. - - Queries and retrieves memories for the specified end user with context-aware responses. + Submit a memory read task. + + Validates the end user, then dispatches the read to a Celery background task. + Returns a task_id for status polling. """ body = await request.json() payload = MemoryReadRequest(**body) logger.info(f"Memory read request - end_user_id: {payload.end_user_id}") - + memory_api_service = MemoryAPIService(db) - - result = await memory_api_service.read_memory( + + result = memory_api_service.read_memory( workspace_id=api_key_auth.workspace_id, end_user_id=payload.end_user_id, message=payload.message, @@ -88,58 +140,95 @@ async def read_memory_api_service( storage_type=payload.storage_type, user_rag_memory_id=payload.user_rag_memory_id, ) - - logger.info(f"Memory read successful for end_user: {payload.end_user_id}") - return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully") + + logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}") + return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted") -@router.get("/configs") +@router.get("/read/status") @require_api_key(scopes=["memory"]) -async def list_memory_configs( +async def get_read_task_status( request: Request, + task_id: str = Query(..., description="Celery task ID"), api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), ): """ - List all memory configs for the workspace. - - Returns all available memory configurations associated with the authorized workspace. + Check the status of a memory read task. + + Returns the current status and result (if completed) of a previously submitted read task. """ - logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}") + logger.info(f"Read task status check - task_id: {task_id}") - memory_api_service = MemoryAPIService(db) + from app.services.task_service import get_task_memory_read_result + result = get_task_memory_read_result(task_id) - result = memory_api_service.list_memory_configs( - workspace_id=api_key_auth.workspace_id, - ) - - logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}") - return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully") + return success(data=_sanitize_task_result(result), msg="Task status retrieved") -@router.post("/end_users") +@router.post("/write/sync") @require_api_key(scopes=["memory"]) -async def create_end_user( +@check_end_user_quota +async def write_memory_sync( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), + message: str = Body(..., description="Message content"), ): """ - Create an end user. - - Creates a new end user for the authorized workspace. - If an end user with the same other_id already exists, returns the existing one. + Write memory synchronously. + + Blocks until the write completes and returns the result directly. + For async processing with task polling, use /write instead. """ body = await request.json() - payload = CreateEndUserRequest(**body) - logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}") + payload = MemoryWriteRequest(**body) + logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}") memory_api_service = MemoryAPIService(db) - result = memory_api_service.create_end_user( + result = await memory_api_service.write_memory_sync( workspace_id=api_key_auth.workspace_id, - other_id=payload.other_id, + end_user_id=payload.end_user_id, + message=payload.message, + config_id=payload.config_id, + storage_type=payload.storage_type, + user_rag_memory_id=payload.user_rag_memory_id, ) - logger.info(f"End user ready: {result['id']}") - return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") + logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}") + return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully") + + +@router.post("/read/sync") +@require_api_key(scopes=["memory"]) +async def read_memory_sync( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(..., description="Query message"), +): + """ + Read memory synchronously. + + Blocks until the read completes and returns the answer directly. + For async processing with task polling, use /read instead. + """ + body = await request.json() + payload = MemoryReadRequest(**body) + logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}") + + memory_api_service = MemoryAPIService(db) + + result = await memory_api_service.read_memory_sync( + workspace_id=api_key_auth.workspace_id, + end_user_id=payload.end_user_id, + message=payload.message, + search_switch=payload.search_switch, + config_id=payload.config_id, + storage_type=payload.storage_type, + user_rag_memory_id=payload.user_rag_memory_id, + ) + + logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}") + return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully") diff --git a/api/app/controllers/service/memory_config_api_controller.py b/api/app/controllers/service/memory_config_api_controller.py new file mode 100644 index 00000000..1e61e0af --- /dev/null +++ b/api/app/controllers/service/memory_config_api_controller.py @@ -0,0 +1,491 @@ +"""Memory Config 服务接口 - 基于 API Key 认证""" + +from typing import Optional +import uuid + +from fastapi import APIRouter, Body, Depends, Header, Query, Request +from fastapi.encoders import jsonable_encoder +from sqlalchemy.orm import Session + +from app.controllers import memory_storage_controller +from app.controllers import memory_forget_controller +from app.controllers import ontology_controller +from app.controllers import emotion_config_controller +from app.controllers import memory_reflection_controller +from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest +from app.controllers.emotion_config_controller import EmotionConfigUpdate +from app.schemas.memory_reflection_schemas import Memory_Reflection +from app.core.api_key_auth import require_api_key +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.logging_config import get_business_logger +from app.core.response_utils import success +from app.db import get_db +from app.repositories.memory_config_repository import MemoryConfigRepository +from app.schemas.api_key_schema import ApiKeyAuth +from app.schemas.memory_api_schema import ( + ConfigUpdateExtractedRequest, + ConfigUpdateRequest, + ListConfigsResponse, + ConfigCreateRequest, + ConfigUpdateForgettingRequest, + EmotionConfigUpdateRequest, + ReflectionConfigUpdateRequest, +) +from app.schemas.memory_storage_schema import ( + ConfigUpdate, + ConfigUpdateExtracted, + ConfigParamsCreate, +) +from app.services import api_key_service +from app.services.memory_api_service import MemoryAPIService +from app.utils.config_utils import resolve_config_id + +router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"]) +logger = get_business_logger() + + +def _get_current_user(api_key_auth: ApiKeyAuth, db: Session): + """Build a current_user object from API key auth + + Args: + api_key_auth: Validated API key auth info + db: Database session + + Returns: + User object with current_workspace_id set + """ + api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id) + current_user = api_key.creator + current_user.current_workspace_id = api_key_auth.workspace_id + return current_user + + +def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session): + """Verify that the config belongs to the workspace. + + Args: + config_id: The ID of the config to verify + workspace_id: The workspace ID tocheck against + db: Database session for querying + Raises: + BusinessException: If the config does not exist or does not belong to the workspace + """ + try: + resolved_id = resolve_config_id(config_id, db) + except ValueError as e: + raise BusinessException( + message=f"Invalid config_id: {e}", + code=BizCode.INVALID_PARAMETER, + ) + config = MemoryConfigRepository.get_by_id(db, resolved_id) + if not config or config.workspace_id != workspace_id: + raise BusinessException( + message="Config not found or access denied", + code=BizCode.MEMORY_CONFIG_NOT_FOUND, + ) + +# @router.get("/configs") +# @require_api_key(scopes=["memory"]) +# async def list_memory_configs( +# request: Request, +# api_key_auth: ApiKeyAuth = None, +# db: Session = Depends(get_db), +# ): +# """ +# List all memory configs for the workspace. + +# Returns all available memory configurations associated with the authorized workspace. +# """ +# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}") + +# memory_api_service = MemoryAPIService(db) + +# result = memory_api_service.list_memory_configs( +# workspace_id=api_key_auth.workspace_id, +# ) + +# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}") +# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully") + +@router.get("/read_all_config") +@require_api_key(scopes=["memory"]) +async def read_all_config( + request:Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + List all memory configs with full details (enhanced version). + + Returns complete config fields for the authorized workspace. + No config_id ownership check needed — results are filtered by workspace. + """ + logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}") + + current_user = _get_current_user(api_key_auth, db) + + return memory_storage_controller.read_all_config( + current_user=current_user, + db=db, + ) + +@router.get("/scenes/simple") +@require_api_key(scopes=["memory"]) +async def get_ontology_scenes( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get available ontology scenes for the workspace. + + Returns a simple list of scene_id and scene_name for dropdown selection. + Used before creating a memory config to choose which ontology scene to associate. + """ + logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}") + + current_user = _get_current_user(api_key_auth, db) + + return await ontology_controller.get_scenes_simple( + db=db, + current_user=current_user, + ) + +@router.get("/read_config_extracted") +@require_api_key(scopes=["memory"]) +async def read_config_extracted( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get extraction engine config details for a specific config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return memory_storage_controller.read_config_extracted( + config_id = config_id, + current_user = current_user, + db = db, + ) + +@router.get("/read_config_forgetting") +@require_api_key(scopes=["memory"]) +async def read_config_forgetting( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get forgetting settings for a specific memory config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + result = await memory_forget_controller.read_forgetting_config( + config_id = config_id, + current_user = current_user, + db = db, + ) + return jsonable_encoder(result) + + + +@router.get("/read_config_emotion") +@require_api_key(scopes=["memory"]) +async def read_config_emotion( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get emotion engine config details for a specific config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return jsonable_encoder(emotion_config_controller.get_emotion_config( + config_id=config_id, + db=db, + current_user=current_user, + )) + +@router.get("/read_config_reflection") +@require_api_key(scopes=["memory"]) +async def read_config_reflection( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get reflection engine config details for a specific config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return jsonable_encoder(await memory_reflection_controller.start_reflection_configs( + config_id=config_id, + current_user=current_user, + db=db, + )) + + +@router.post("/create_config") +@require_api_key(scopes=["memory"]) +async def create_memory_config( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), +): + """ + Create a new memory config for the workspace. + + The config will be associated with the workspace of the API Key. + config_name is required, other fields are optional. + """ + body = await request.json() + payload = ConfigCreateRequest(**body) + + logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}") + + # 构造管理端 Schema,workspace_id 从 API Key 注入 + current_user = _get_current_user(api_key_auth, db) + mgmt_payload = ConfigParamsCreate( + config_name=payload.config_name, + config_desc=payload.config_desc or "", + scene_id=payload.scene_id, + llm_id=payload.llm_id, + embedding_id=payload.embedding_id, + rerank_id=payload.rerank_id, + reflection_model_id=payload.reflection_model_id, + emotion_model_id=payload.emotion_model_id, + ) + #将返回数据中UUID序列化处理 + result =memory_storage_controller.create_config( + payload=mgmt_payload, + current_user=current_user, + db=db, + x_language_type=x_language_type, + ) + return jsonable_encoder(result) + +@router.put("/update_config") +@require_api_key(scopes=["memory"]) +async def update_memory_config( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update memory config basic info (name, description, scene). + + Requires API Key with 'memory' scope + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ConfigUpdateRequest(**body) + + logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + mgmt_payload = ConfigUpdate( + config_id = payload.config_id, + config_name = payload.config_name, + config_desc = payload.config_desc, + scene_id = payload.scene_id, + ) + + return memory_storage_controller.update_config( + payload = mgmt_payload, + current_user = current_user, + db = db, + ) + +@router.put("/update_config_extracted") +@require_api_key(scopes=["memory"]) +async def update_memory_config_extracted( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + update memory config extraction engine config (models, thresholds, chunking, pruning, etc.). + + Requires API Key with 'memory' scope. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ConfigUpdateExtractedRequest(**body) + + logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + #校验权限 + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = ConfigUpdateExtracted(**update_fields) + + return memory_storage_controller.update_config_extracted( + payload = mgmt_payload, + current_user = current_user, + db = db, + ) + +@router.put("/update_config_forgetting") +@require_api_key(scopes=["memory"]) +async def update_memory_config_forgetting( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + update memory config forgetting settings (forgetting strategy, parameters, etc.). + + Requires API Key with 'memory' scope. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ConfigUpdateForgettingRequest(**body) + + logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + #校验权限 + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = ForgettingConfigUpdateRequest(**update_fields) + + #将返回数据中UUID序列化处理 + result = await memory_forget_controller.update_forgetting_config( + payload = mgmt_payload, + current_user = current_user, + db = db, + ) + return jsonable_encoder(result) + +@router.put("/update_config_emotion") +@require_api_key(scopes=["memory"]) +async def update_config_emotion( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update emotion engine config (full update). + + All fields except emotion_model_id are required. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = EmotionConfigUpdateRequest(**body) + + logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = EmotionConfigUpdate(**update_fields) + return jsonable_encoder(emotion_config_controller.update_emotion_config( + config=mgmt_payload, + db=db, + current_user=current_user, + )) + +@router.put("/update_config_reflection") +@require_api_key(scopes=["memory"]) +async def update_config_reflection( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update reflection engine config (full update). + + All fields are required. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ReflectionConfigUpdateRequest(**body) + + logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = Memory_Reflection(**update_fields) + + return jsonable_encoder(await memory_reflection_controller.save_reflection_config( + request=mgmt_payload, + current_user=current_user, + db=db, + )) + +@router.delete("/delete_config") +@require_api_key(scopes=["memory"]) +async def delete_memory_config( + config_id: str, + request: Request, + force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Delete a memory config. + + - Default configs cannot be deleted. + - If end users are connected and force=False, returns a warning. + - If force=True, clears end user references and deletes the config. + + Only configs belonging to the authorized workspace can be deleted. + """ + logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return memory_storage_controller.delete_config( + config_id=config_id, + force=force, + current_user=current_user, + db=db, + ) diff --git a/api/app/controllers/skill_controller.py b/api/app/controllers/skill_controller.py index 6e673679..4ee07c7d 100644 --- a/api/app/controllers/skill_controller.py +++ b/api/app/controllers/skill_controller.py @@ -11,11 +11,13 @@ from app.schemas import skill_schema from app.schemas.response_schema import PageData, PageMeta from app.services.skill_service import SkillService from app.core.response_utils import success +from app.core.quota_stub import check_skill_quota router = APIRouter(prefix="/skills", tags=["Skills"]) @router.post("", summary="创建技能") +@check_skill_quota def create_skill( data: skill_schema.SkillCreate, db: Session = Depends(get_db), diff --git a/api/app/controllers/tenant_subscription_controller.py b/api/app/controllers/tenant_subscription_controller.py new file mode 100644 index 00000000..62edb777 --- /dev/null +++ b/api/app/controllers/tenant_subscription_controller.py @@ -0,0 +1,173 @@ +""" +租户套餐查询接口(普通用户可访问) +""" +import datetime +from typing import Callable, Optional + +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse +from sqlalchemy.orm import Session + +from app.core.logging_config import get_api_logger +from app.core.response_utils import success, fail +from app.db import get_db +from app.dependencies import get_current_user +from app.i18n.dependencies import get_translator +from app.models.user_model import User +from app.schemas.response_schema import ApiResponse + +logger = get_api_logger() + +router = APIRouter(prefix="/tenant", tags=["Tenant"]) +public_router = APIRouter(tags=["Tenant"]) + + +@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息") +async def get_my_tenant_subscription( + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + t: Callable = Depends(get_translator), +): + """ + 获取当前登录用户所属租户的有效套餐订阅信息。 + 包含套餐名称、版本、配额、到期时间等。 + """ + try: + from premium.platform_admin.package_plan_service import TenantSubscriptionService + + if not current_user.tenant: + return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户")) + + tenant_id = current_user.tenant.id + svc = TenantSubscriptionService(db) + sub = svc.get_subscription(tenant_id) + + if not sub: + # 无订阅记录时,兜底返回免费套餐信息 + free_plan = svc.plan_repo.get_free_plan() + if not free_plan: + return success(data=None, msg="暂无有效套餐") + return success(data={ + "subscription_id": None, + "tenant_id": str(tenant_id), + "package_plan_id": str(free_plan.id), + "package_version": free_plan.version, + "package_plan": { + "id": str(free_plan.id), + "name": free_plan.name, + "name_en": free_plan.name_en, + "version": free_plan.version, + "category": free_plan.category, + "tier_level": free_plan.tier_level, + "price": float(free_plan.price) if free_plan.price is not None else 0.0, + "billing_cycle": free_plan.billing_cycle, + "core_value": free_plan.core_value, + "core_value_en": free_plan.core_value_en, + "tech_support": free_plan.tech_support, + "tech_support_en": free_plan.tech_support_en, + "sla_compliance": free_plan.sla_compliance, + "sla_compliance_en": free_plan.sla_compliance_en, + "page_customization": free_plan.page_customization, + "page_customization_en": free_plan.page_customization_en, + "theme_color": free_plan.theme_color, + }, + "started_at": None, + "expired_at": None, + "status": "active", + "quotas": free_plan.quotas or {}, + "created_at": int(datetime.datetime.utcnow().timestamp() * 1000), + "updated_at": int(datetime.datetime.utcnow().timestamp() * 1000), + }, msg="免费套餐") + + return success(data=svc.build_response(sub)) + + except ModuleNotFoundError: + # 社区版无 premium 模块,从配置文件读取免费套餐 + if not current_user.tenant: + return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户")) + + from app.config.default_free_plan import DEFAULT_FREE_PLAN + + plan = DEFAULT_FREE_PLAN + response_data = { + "subscription_id": None, + "tenant_id": str(current_user.tenant.id), + "package_plan_id": None, + "package_version": plan["version"], + "package_plan": { + "id": None, + "name": plan["name"], + "name_en": plan.get("name_en"), + "version": plan["version"], + "category": plan["category"], + "tier_level": plan["tier_level"], + "price": float(plan["price"]), + "billing_cycle": plan["billing_cycle"], + "core_value": plan.get("core_value"), + "core_value_en": plan.get("core_value_en"), + "tech_support": plan.get("tech_support"), + "tech_support_en": plan.get("tech_support_en"), + "sla_compliance": plan.get("sla_compliance"), + "sla_compliance_en": plan.get("sla_compliance_en"), + "page_customization": plan.get("page_customization"), + "page_customization_en": plan.get("page_customization_en"), + "theme_color": plan.get("theme_color"), + }, + "started_at": None, + "expired_at": None, + "status": "active", + "quotas": plan["quotas"], + "created_at": int(datetime.datetime.utcnow().timestamp() * 1000), + "updated_at": int(datetime.datetime.utcnow().timestamp() * 1000), + } + return success(data=response_data, msg="社区版免费套餐") + + except Exception as e: + logger.error(f"获取租户套餐信息失败: {e}", exc_info=True) + return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败")) + + +@public_router.get("/package-plans", response_model=ApiResponse, summary="获取套餐列表(公开)") +async def list_package_plans_public( + category: Optional[str] = None, + status: Optional[bool] = None, + search: Optional[str] = None, + db: Session = Depends(get_db), +): + """ + 公开接口,无需鉴权。 + SaaS 版从数据库读取套餐列表;社区版降级返回 default_free_plan.py 中的免费套餐。 + """ + try: + from premium.platform_admin.package_plan_service import PackagePlanService + from premium.platform_admin.package_plan_schema import PackagePlanResponse + svc = PackagePlanService(db) + result = svc.get_list(page=1, size=9999, category=category, status=status, search=search) + return success(data=[PackagePlanResponse.model_validate(p).model_dump(mode="json") for p in result["items"]]) + except ModuleNotFoundError: + from app.config.default_free_plan import DEFAULT_FREE_PLAN + plan = DEFAULT_FREE_PLAN + return success(data=[{ + "id": None, + "name": plan["name"], + "name_en": plan.get("name_en"), + "version": plan["version"], + "category": plan["category"], + "tier_level": plan["tier_level"], + "price": float(plan["price"]), + "billing_cycle": plan["billing_cycle"], + "core_value": plan.get("core_value"), + "core_value_en": plan.get("core_value_en"), + "tech_support": plan.get("tech_support"), + "tech_support_en": plan.get("tech_support_en"), + "sla_compliance": plan.get("sla_compliance"), + "sla_compliance_en": plan.get("sla_compliance_en"), + "page_customization": plan.get("page_customization"), + "page_customization_en": plan.get("page_customization_en"), + "theme_color": plan.get("theme_color"), + "status": plan.get("status", True), + "quotas": plan["quotas"], + }]) + except Exception as e: + logger.error(f"获取套餐列表失败: {e}", exc_info=True) + return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐列表失败")) diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index cc16a6b4..5a329165 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -114,11 +114,14 @@ def get_current_user_info( # 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限 if current_user.external_source: - from premium.sso.models import SSOSource - source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() - if source and source.permissions: - result_schema.permissions = source.permissions - else: + try: + from premium.sso.models import SSOSource + source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() + if source and source.permissions: + result_schema.permissions = source.permissions + else: + result_schema.permissions = [] + except ModuleNotFoundError: result_schema.permissions = [] else: result_schema.permissions = ["all"] diff --git a/api/app/controllers/workspace_controller.py b/api/app/controllers/workspace_controller.py index 6f4a4fa8..47068288 100644 --- a/api/app/controllers/workspace_controller.py +++ b/api/app/controllers/workspace_controller.py @@ -35,6 +35,7 @@ from app.schemas.workspace_schema import ( WorkspaceUpdate, ) from app.services import workspace_service +from app.core.quota_stub import check_workspace_quota # 获取API专用日志器 api_logger = get_api_logger() @@ -106,6 +107,7 @@ def get_workspaces( @router.post("", response_model=ApiResponse) +@check_workspace_quota def create_workspace( workspace: WorkspaceCreate, language_type: str = Header(default="zh", alias="X-Language-Type"), diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index ca7172e8..a3d1d308 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -12,7 +12,7 @@ import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from langchain.agents import create_agent -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.tools import BaseTool from langgraph.errors import GraphRecursionError @@ -41,6 +41,7 @@ class LangChainAgent: max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数 deep_thinking: bool = False, # 是否启用深度思考模式 thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算 + json_output: bool = False, # 是否强制 JSON 输出 capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考 ): """初始化 LangChain Agent @@ -64,7 +65,6 @@ class LangChainAgent: self.streaming = streaming self.is_omni = is_omni self.max_tool_consecutive_calls = max_tool_consecutive_calls - self.deep_thinking = deep_thinking and ("thinking" in (capability or [])) # 工具调用计数器:记录每个工具的连续调用次数 self.tool_call_counter: Dict[str, int] = {} @@ -80,6 +80,17 @@ class LangChainAgent: self.system_prompt = system_prompt or "你是一个专业的AI助手" + # ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format + # 在 system prompt 中注入 JSON 要求 + from app.models.models_model import ModelProvider + if json_output and ( + (provider.lower() == ModelProvider.DASHSCOPE and not is_omni) + or provider.lower() == ModelProvider.VOLCANO + # 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出 + or bool(tools) + ): + self.system_prompt += "\n请以JSON格式输出。" + logger.debug( f"Agent 迭代次数配置: max_iterations={self.max_iterations}, " f"tool_count={len(self.tools)}, " @@ -87,23 +98,17 @@ class LangChainAgent: f"auto_calculated={max_iterations is None}" ) - # 根据 capability 校验是否真正支持深度思考 - actual_deep_thinking = self.deep_thinking - if deep_thinking and not actual_deep_thinking: - logger.warning( - f"模型 {model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking" - ) - - # 创建 RedBearLLM(支持多提供商) + # 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理 model_config = RedBearModelConfig( model_name=model_name, provider=provider, api_key=api_key, base_url=api_base, is_omni=is_omni, - deep_thinking=actual_deep_thinking, - thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None, - support_thinking="thinking" in (capability or []), + capability=capability, + deep_thinking=deep_thinking, + thinking_budget_tokens=thinking_budget_tokens, + json_output=json_output, extra_params={ "temperature": temperature, "max_tokens": max_tokens, @@ -112,6 +117,9 @@ class LangChainAgent: ) self.llm = RedBearLLM(model_config, type=ModelType.CHAT) + # 从经过校验的 config 读取实际生效的能力开关 + self.deep_thinking = model_config.deep_thinking + self.json_output = model_config.json_output # 获取底层模型用于真正的流式调用 self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm @@ -237,9 +245,7 @@ class LangChainAgent: Returns: List[BaseMessage]: 消息列表 """ - messages:list = [SystemMessage(content=self.system_prompt)] - - # 添加系统提示词 + messages: list = [] # 添加历史消息 if history: diff --git a/api/app/core/api_key_auth.py b/api/app/core/api_key_auth.py index 342405b8..91d6bd8a 100644 --- a/api/app/core/api_key_auth.py +++ b/api/app/core/api_key_auth.py @@ -96,6 +96,38 @@ def require_api_key( resource_id=api_key_obj.resource_id, ) + # ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit)────────── + try: + from app.models.workspace_model import Workspace + from premium.platform_admin.package_plan_service import TenantSubscriptionService + + workspace = db.query(Workspace).filter( + Workspace.id == api_key_obj.workspace_id + ).first() + if workspace: + quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id) + tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None + if tenant_qps_limit: + rate_limiter = RateLimiterService() + tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit( + workspace.tenant_id, tenant_qps_limit + ) + if not tenant_ok: + raise RateLimitException( + "租户 API 调用速率超限", + BizCode.API_KEY_QPS_LIMIT_EXCEEDED, + rate_headers={ + "X-RateLimit-Tenant-Limit": str(tenant_info["limit"]), + "X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]), + "X-RateLimit-Tenant-Reset": str(tenant_info["reset"]), + } + ) + except RateLimitException: + raise + except Exception as e: + logger.warning(f"Tenant 限速检查异常,跳过: {e}") + # ───────────────────────────────────────────────────────────── + rate_limiter = RateLimiterService() is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj) if not is_allowed: diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index bae4643e..3b0ea1ee 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -14,6 +14,7 @@ from dotenv import load_dotenv from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs +from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \ memory_summary_generation @@ -191,15 +192,37 @@ async def write( if success: logger.info("Successfully saved all data to Neo4j") - # 使用 Celery 异步任务触发聚类(不阻塞主流程) if all_entity_nodes: + end_user_id = all_entity_nodes[0].end_user_id + + # Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体 + try: + from app.repositories.end_user_info_repository import EndUserInfoRepository + if end_user_id: + with get_db_context() as db_session: + info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id)) + pg_aliases = info.aliases if info and info.aliases else [] + if info is not None: + # 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码 + placeholder_names = list(_USER_PLACEHOLDER_NAMES) + await neo4j_connector.execute_query( + """ + MATCH (e:ExtractedEntity) + WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names + SET e.aliases = $aliases + """, + end_user_id=end_user_id, aliases=pg_aliases, + placeholder_names=placeholder_names, + ) + logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}") + except Exception as sync_err: + logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}") + + # 使用 Celery 异步任务触发聚类(不阻塞主流程) try: from app.tasks import run_incremental_clustering - end_user_id = all_entity_nodes[0].end_user_id new_entity_ids = [e.id for e in all_entity_nodes] - - # 异步提交 Celery 任务 task = run_incremental_clustering.apply_async( kwargs={ "end_user_id": end_user_id, @@ -207,7 +230,6 @@ async def write( "llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None, "embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, }, - # 设置任务优先级(低优先级,不影响主业务) priority=3, ) logger.info( @@ -215,7 +237,6 @@ async def write( f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}" ) except Exception as e: - # 聚类任务提交失败不影响主流程 logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True) break diff --git a/api/app/core/memory/models/__init__.py b/api/app/core/memory/models/__init__.py index eed8e8c4..2a34159b 100644 --- a/api/app/core/memory/models/__init__.py +++ b/api/app/core/memory/models/__init__.py @@ -61,9 +61,9 @@ from app.core.memory.models.triplet_models import ( # User metadata models from app.core.memory.models.metadata_models import ( UserMetadata, - UserMetadataBehavioralHints, UserMetadataProfile, MetadataExtractionResponse, + MetadataFieldChange, ) # Ontology scenario models (LLM extracted from scenarios) @@ -133,9 +133,9 @@ __all__ = [ "Triplet", "TripletExtractionResponse", "UserMetadata", - "UserMetadataBehavioralHints", "UserMetadataProfile", "MetadataExtractionResponse", + "MetadataFieldChange", # Ontology models "OntologyClass", "OntologyExtractionResponse", diff --git a/api/app/core/memory/models/metadata_models.py b/api/app/core/memory/models/metadata_models.py index 55c2359e..e12c3d97 100644 --- a/api/app/core/memory/models/metadata_models.py +++ b/api/app/core/memory/models/metadata_models.py @@ -4,7 +4,7 @@ Independent from triplet_models.py - these models are used by the standalone metadata extraction pipeline (post-dedup async Celery task). """ -from typing import List +from typing import List, Literal, Optional from pydantic import BaseModel, ConfigDict, Field @@ -13,8 +13,8 @@ class UserMetadataProfile(BaseModel): """用户画像信息""" model_config = ConfigDict(extra="ignore") - role: str = Field(default="", description="用户职业或角色") - domain: str = Field(default="", description="用户所在领域") + role: List[str] = Field(default_factory=list, description="用户职业或角色") + domain: List[str] = Field(default_factory=list, description="用户所在领域") expertise: List[str] = Field( default_factory=list, description="用户擅长的技能或工具" ) @@ -23,31 +23,37 @@ class UserMetadataProfile(BaseModel): ) -class UserMetadataBehavioralHints(BaseModel): - """行为偏好""" - - model_config = ConfigDict(extra="ignore") - learning_stage: str = Field(default="", description="学习阶段") - preferred_depth: str = Field(default="", description="偏好深度") - tone_preference: str = Field(default="", description="语气偏好") - - class UserMetadata(BaseModel): """用户元数据顶层结构""" model_config = ConfigDict(extra="ignore") profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile) - behavioral_hints: UserMetadataBehavioralHints = Field( - default_factory=UserMetadataBehavioralHints + + +class MetadataFieldChange(BaseModel): + """单个元数据字段的变更操作""" + + model_config = ConfigDict(extra="ignore") + field_path: str = Field( + description="字段路径,用点号分隔,如 'profile.role'、'profile.expertise'" + ) + action: Literal["set", "remove"] = Field( + description="操作类型:'set' 表示新增或修改,'remove' 表示移除" + ) + value: Optional[str] = Field( + default=None, + description="字段的新值(action='set' 时必填)。标量字段直接填值,列表字段填单个要新增的元素" ) - knowledge_tags: List[str] = Field(default_factory=list, description="知识标签") class MetadataExtractionResponse(BaseModel): - """元数据提取 LLM 响应结构""" + """元数据提取 LLM 响应结构(增量模式)""" model_config = ConfigDict(extra="ignore") - user_metadata: UserMetadata = Field(default_factory=UserMetadata) + metadata_changes: List[MetadataFieldChange] = Field( + default_factory=list, + description="元数据的增量变更列表,每项描述一个字段的新增、修改或移除操作", + ) aliases_to_add: List[str] = Field( default_factory=list, description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)", diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index 7e0976fe..715f190c 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -82,51 +82,38 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode): canonical.connect_strength = next(iter(pair)) # 别名合并(去重保序,使用标准化工具) + # 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改 try: canonical_name = (getattr(canonical, "name", "") or "").strip() - incoming_name = (getattr(ent, "name", "") or "").strip() - - # 收集所有需要合并的别名 - all_aliases = [] - - # 1. 添加canonical现有的别名 - existing = getattr(canonical, "aliases", []) or [] - all_aliases.extend(existing) - - # 2. 添加incoming实体的名称(如果不同于canonical的名称) - if incoming_name and incoming_name != canonical_name: - all_aliases.append(incoming_name) - - # 3. 添加incoming实体的所有别名 - incoming = getattr(ent, "aliases", []) or [] - all_aliases.extend(incoming) - - # 4. 标准化并去重(优先使用alias_utils工具函数) - try: - from app.core.memory.utils.alias_utils import normalize_aliases - canonical.aliases = normalize_aliases(canonical_name, all_aliases) - except Exception: - # 如果导入失败,使用增强的去重逻辑 - seen_normalized = set() - unique_aliases = [] + if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES: + incoming_name = (getattr(ent, "name", "") or "").strip() - for alias in all_aliases: - if not alias: - continue - - alias_stripped = str(alias).strip() - if not alias_stripped or alias_stripped == canonical_name: - continue - - # 标准化:转小写用于去重判断 - alias_normalized = alias_stripped.lower() - - if alias_normalized not in seen_normalized: - seen_normalized.add(alias_normalized) - unique_aliases.append(alias_stripped) + # 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体 + all_aliases = list(getattr(canonical, "aliases", []) or []) + if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES: + all_aliases.append(incoming_name) + all_aliases.extend( + a for a in (getattr(ent, "aliases", []) or []) + if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES + ) - # 排序并赋值 - canonical.aliases = sorted(unique_aliases) + try: + from app.core.memory.utils.alias_utils import normalize_aliases + canonical.aliases = normalize_aliases(canonical_name, all_aliases) + except Exception: + seen_normalized = set() + unique_aliases = [] + for alias in all_aliases: + if not alias: + continue + alias_stripped = str(alias).strip() + if not alias_stripped or alias_stripped == canonical_name: + continue + alias_normalized = alias_stripped.lower() + if alias_normalized not in seen_normalized: + seen_normalized.add(alias_normalized) + unique_aliases.append(alias_stripped) + canonical.aliases = sorted(unique_aliases) except Exception: pass @@ -733,66 +720,37 @@ def fuzzy_match( def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode): - """ 模糊匹配中的实体合并。 + """模糊匹配中的实体合并(别名部分)。 - 合并策略: - 1. 保留canonical的主名称不变 - 2. 将losing的主名称添加为alias(如果不同) - 3. 合并两个实体的所有aliases - 4. 自动去重(case-insensitive)并排序 - - Args: - canonical: 规范实体(保留) - losing: 被合并实体(删除) - - Note: - 使用alias_utils.normalize_aliases进行标准化去重 + 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。 """ - # 获取规范实体的名称 canonical_name = (getattr(canonical, "name", "") or "").strip() + if canonical_name.lower() in _USER_PLACEHOLDER_NAMES: + return + losing_name = (getattr(losing, "name", "") or "").strip() - # 收集所有需要合并的别名 - all_aliases = [] - - # 1. 添加canonical现有的别名 - current_aliases = getattr(canonical, "aliases", []) or [] - all_aliases.extend(current_aliases) - - # 2. 添加losing实体的名称(如果不同于canonical的名称) + all_aliases = list(getattr(canonical, "aliases", []) or []) if losing_name and losing_name != canonical_name: all_aliases.append(losing_name) + all_aliases.extend(getattr(losing, "aliases", []) or []) - # 3. 添加losing实体的所有别名 - losing_aliases = getattr(losing, "aliases", []) or [] - all_aliases.extend(losing_aliases) - - # 4. 标准化并去重(使用标准化后的字符串进行去重) try: from app.core.memory.utils.alias_utils import normalize_aliases canonical.aliases = normalize_aliases(canonical_name, all_aliases) except Exception: - # 如果导入失败,使用增强的去重逻辑 - # 使用标准化后的字符串作为key进行去重 seen_normalized = set() unique_aliases = [] - for alias in all_aliases: if not alias: continue - alias_stripped = str(alias).strip() if not alias_stripped or alias_stripped == canonical_name: continue - - # 标准化:转小写用于去重判断 alias_normalized = alias_stripped.lower() - if alias_normalized not in seen_normalized: seen_normalized.add(alias_normalized) unique_aliases.append(alias_stripped) - - # 排序并赋值 canonical.aliases = sorted(unique_aliases) # ========== 主循环:遍历所有实体对进行模糊匹配 ========== diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 5636dcb5..75fc87d2 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1391,18 +1391,18 @@ class ExtractionOrchestrator: """ 将本轮提取的用户别名同步到 end_user 和 end_user_info 表。 - 注意:此方法在 Neo4j 写入之前调用,因此不能依赖 Neo4j 作为别名的权威数据源。 - 改为直接使用内存中去重后的 entity_nodes 的 aliases,与 PgSQL 已有的 aliases 合并。 + PgSQL end_user_info.aliases 是用户别名的唯一权威源。 + 此方法仅将本轮 LLM 从对话中新提取的别名增量追加到 PgSQL, + 不再从 Neo4j 二层去重合并历史别名,避免脏数据反向污染 PgSQL。 策略: - 1. 从内存中的 entity_nodes 提取本轮用户别名(current_aliases) - 2. 从去重后的 entity_nodes 中提取完整别名(含 Neo4j 二层去重合并的历史别名) - 3. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases) - 4. 合并 db_aliases + deduped_aliases + current_aliases,去重保序 - 5. 写回 PgSQL + 1. 从本轮对话原始发言中提取用户别名(current_aliases) + 2. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases) + 3. 合并 db_aliases + current_aliases,去重保序 + 4. 写回 PgSQL Args: - entity_nodes: 去重后的实体节点列表(内存中,含二层去重合并结果) + entity_nodes: 去重后的实体节点列表(内存中) dialog_data_list: 对话数据列表 """ try: @@ -1418,11 +1418,6 @@ class ExtractionOrchestrator: # 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序) current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list) - # 1.5 从去重后的 entity_nodes 中提取完整别名 - # 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 中, - # 这里提取出来确保 PgSQL 与 Neo4j 的别名保持同步 - deduped_aliases = self._extract_deduped_entity_aliases(entity_nodes) - # 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源 # (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中) neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id) @@ -1434,19 +1429,12 @@ class ExtractionOrchestrator: ] if len(current_aliases) < before_count: logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名") - # 同样过滤 deduped_aliases - deduped_aliases = [ - a for a in deduped_aliases - if a.strip().lower() not in neo4j_assistant_aliases - ] - if not current_aliases and not deduped_aliases: + if not current_aliases: logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}") return logger.info(f"本轮对话提取的 aliases: {current_aliases}") - if deduped_aliases: - logger.info(f"去重后实体的完整 aliases(含历史): {deduped_aliases}") # 2. 同步到数据库 end_user_uuid = uuid.UUID(end_user_id) @@ -1457,21 +1445,15 @@ class ExtractionOrchestrator: logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录") return - # 3. 从 PgSQL 读取已有 aliases 并与本轮合并 + # 3. 从 PgSQL 读取已有 aliases 并与本轮新增合并 info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) db_aliases = (info.aliases if info and info.aliases else []) # 过滤掉占位名称 db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES] - # 合并:已有 + 去重后完整别名 + 本轮新增,去重保序 + # 合并:PgSQL 已有 + 本轮新增,去重保序(不再合并 Neo4j 历史别名) merged_aliases = list(db_aliases) seen_lower = {a.strip().lower() for a in merged_aliases} - # 先合并去重后实体的完整别名(含 Neo4j 历史别名) - for alias in deduped_aliases: - if alias.strip().lower() not in seen_lower: - merged_aliases.append(alias) - seen_lower.add(alias.strip().lower()) - # 再合并本轮新提取的别名 for alias in current_aliases: if alias.strip().lower() not in seen_lower: merged_aliases.append(alias) @@ -1505,9 +1487,7 @@ class ExtractionOrchestrator: info.aliases = merged_aliases logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}") else: - first_alias = current_aliases[0].strip() if current_aliases else ( - deduped_aliases[0].strip() if deduped_aliases else "" - ) + first_alias = current_aliases[0].strip() if current_aliases else "" # 确保 first_alias 不是占位名称 if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES: db.add(EndUserInfo( diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py index 19f1e533..29f4e85b 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py @@ -118,7 +118,7 @@ class MetadataExtractor: existing_aliases: Optional[List[str]] = None, ) -> Optional[tuple]: """ - 对筛选后的 statement 列表调用 LLM 提取元数据和用户别名。 + 对筛选后的 statement 列表调用 LLM 提取元数据增量变更和用户别名。 Args: statements: 用户发言的 statement 文本列表 @@ -126,7 +126,8 @@ class MetadataExtractor: existing_aliases: 数据库已有的用户别名列表(可选) Returns: - (UserMetadata, List[str], List[str]) tuple: (metadata, aliases_to_add, aliases_to_remove) on success, None on failure + (List[MetadataFieldChange], List[str], List[str]) tuple: + (metadata_changes, aliases_to_add, aliases_to_remove) on success, None on failure """ if not statements: return None @@ -160,12 +161,12 @@ class MetadataExtractor: ) if response: - metadata = response.user_metadata if response.user_metadata else None + changes = response.metadata_changes if response.metadata_changes else [] to_add = response.aliases_to_add if response.aliases_to_add else [] to_remove = ( response.aliases_to_remove if response.aliases_to_remove else [] ) - return metadata, to_add, to_remove + return changes, to_add, to_remove logger.warning("LLM 返回的响应为空") return None diff --git a/api/app/core/memory/storage_services/search/keyword_search.py b/api/app/core/memory/storage_services/search/keyword_search.py deleted file mode 100644 index e69de29b..00000000 diff --git a/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 index 5d019b12..1c32d369 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 @@ -1,5 +1,5 @@ ===Task=== -Extract user metadata from the following conversation statements spoken by the user. +Extract user metadata changes from the following conversation statements spoken by the user. {% if language == "zh" %} **"三度原则"判断标准:** @@ -10,28 +10,36 @@ Extract user metadata from the following conversation statements spoken by the u **提取规则:** - **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息 - 仅提取文本中明确提到的信息,不要推测 -- 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象 - **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值) +**增量模式(重要):** +你只需要输出**本次对话引起的变更操作**,不要输出完整的元数据。每个变更是一个对象,包含: +- `field_path`:字段路径,用点号分隔(如 `profile.role`、`profile.expertise`) +- `action`:操作类型 + * `set`:新增或修改一个字段的值 + * `remove`:移除一个字段的值 +- `value`:字段的新值(`action="set"` 时必填,`action="remove"` 时填要移除的元素值) + * 所有字段均为列表类型,每个元素一条变更记录 + +**判断规则:** +- 用户提到新信息 → `action="set"`,填入新值 +- 用户明确否定已有信息(如"我不再做老师了"、"我已经不学Python了")→ `action="remove"`,`value` 填要移除的元素值 +- 如果本次对话没有任何可提取的变更,返回空的 `metadata_changes` 数组 `[]` +- **不要为未被提及的字段生成任何变更操作** + {% if existing_metadata %} -**重要:合并已有元数据** -下方提供了数据库中已有的用户元数据。请结合用户最新发言,输出**合并后的完整元数据**: -- 如果用户明确否定了已有信息(如"我不再教高中物理了"),在输出中**移除**该信息 -- 如果用户提到了新信息,**添加**到对应字段中 -- 如果已有信息未被用户否定,**保留**在输出中 -- 标量字段(如 role、domain):如果用户提到了新值,用新值替换;否则保留已有值 -- 最终输出应该是完整的、合并后的元数据,不是增量 +**已有元数据(仅供参考,用于判断是否需要变更):** +请对比已有数据和用户最新发言,只输出差异部分的变更操作。 +- 如果用户说的信息和已有数据一致,不需要输出变更 +- 如果用户否定了已有数据中的某个值,输出 `remove` 操作 +- 如果用户提到了新信息,输出 `set` 操作 {% endif %} **字段说明:** -- profile.role:用户的职业或角色,如 教师、医生、后端工程师 -- profile.domain:用户所在领域,如 教育、医疗、软件开发 -- profile.expertise:用户擅长的技能或工具(通用,不限于编程),如 Python、心理咨询、高中物理 -- profile.interests:用户主动表达兴趣的话题或领域标签 -- behavioral_hints.learning_stage:学习阶段(初学者/中级/高级) -- behavioral_hints.preferred_depth:偏好深度(概览/技术细节/深入探讨) -- behavioral_hints.tone_preference:语气偏好(轻松随意/专业简洁/学术严谨) -- knowledge_tags:用户涉及的知识领域标签 +- profile.role:用户的职业或角色(列表),如 教师、医生、后端工程师,一个人可以有多个角色 +- profile.domain:用户所在领域(列表),如 教育、医疗、软件开发,一个人可以涉及多个领域 +- profile.expertise:用户擅长的技能或工具(列表),如 Python、心理咨询、高中物理 +- profile.interests:用户主动表达兴趣的话题或领域标签(列表) **用户别名变更(增量模式):** - **aliases_to_add**:本次新发现的用户别名,包括: @@ -43,7 +51,6 @@ Extract user metadata from the following conversation statements spoken by the u - **aliases_to_remove**:用户明确否认的别名,包括: * 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组 * **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名 - * 例如:用户说"我不叫陈小刀了" → 只移除"陈小刀",不要移除"陈哥"、"老陈"等未被提及的别名 * 如果没有要移除的别名,返回空数组 `[]` {% if existing_aliases %} - 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复) @@ -57,28 +64,36 @@ Extract user metadata from the following conversation statements spoken by the u **Extraction rules:** - **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user - Only extract information explicitly mentioned in the text, do not speculate -- If no user profile information can be extracted, return an empty user_metadata object - **Output language must match the input text language** +**Incremental mode (important):** +You should only output **the change operations caused by this conversation**, not the complete metadata. Each change is an object containing: +- `field_path`: Field path separated by dots (e.g. `profile.role`, `profile.expertise`) +- `action`: Operation type + * `set`: Add or update a field value + * `remove`: Remove a field value +- `value`: The new value for the field (required when `action="set"`, for `action="remove"` fill in the element value to remove) + * All fields are list types, one change record per element + +**Decision rules:** +- User mentions new information → `action="set"`, fill in the new value +- User explicitly negates existing info (e.g. "I'm no longer a teacher", "I stopped learning Python") → `action="remove"`, `value` is the element to remove +- If this conversation has no extractable changes, return an empty `metadata_changes` array `[]` +- **Do NOT generate any change operations for fields not mentioned in the conversation** + {% if existing_metadata %} -**Important: Merge with existing metadata** -Existing user metadata from the database is provided below. Combine with the user's latest statements to output the **complete merged metadata**: -- If the user explicitly negates existing info (e.g. "I no longer teach high school physics"), **remove** it from output -- If the user mentions new info, **add** it to the corresponding field -- If existing info is not negated by the user, **keep** it in the output -- Scalar fields (e.g. role, domain): replace with new value if user mentions one; otherwise keep existing -- The final output should be the complete, merged metadata — not an incremental update +**Existing metadata (for reference only, to determine if changes are needed):** +Compare existing data with the user's latest statements, and only output change operations for the differences. +- If the user's statement matches existing data, no change is needed +- If the user negates a value in existing data, output a `remove` operation +- If the user mentions new information, output a `set` operation {% endif %} **Field descriptions:** -- profile.role: User's occupation or role, e.g. teacher, doctor, software engineer -- profile.domain: User's domain, e.g. education, healthcare, software development -- profile.expertise: User's skills or tools (general, not limited to programming) -- profile.interests: Topics or domain tags the user actively expressed interest in -- behavioral_hints.learning_stage: Learning stage (beginner/intermediate/advanced) -- behavioral_hints.preferred_depth: Preferred depth (overview/detailed/deep dive) -- behavioral_hints.tone_preference: Tone preference (casual/professional/academic) -- knowledge_tags: Knowledge domain tags related to the user +- profile.role: User's occupation or role (list), e.g. teacher, doctor, software engineer. A person can have multiple roles +- profile.domain: User's domain (list), e.g. education, healthcare, software development. A person can span multiple domains +- profile.expertise: User's skills or tools (list), e.g. Python, counseling, physics +- profile.interests: Topics or domain tags the user actively expressed interest in (list) **User alias changes (incremental mode):** - **aliases_to_add**: Newly discovered user aliases from this conversation, including: @@ -90,7 +105,6 @@ Existing user metadata from the database is provided below. Combine with the use - **aliases_to_remove**: Aliases the user explicitly denies, including: * User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array * **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases - * Example: User says "I'm not called John anymore" → only remove "John", do NOT remove "Johnny", "J" or other related aliases not mentioned * If no aliases to remove, return empty array `[]` {% if existing_aliases %} - Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output) @@ -113,20 +127,11 @@ Existing user metadata from the database is provided below. Combine with the use Return a JSON object with the following structure: ```json { - "user_metadata": { - "profile": { - "role": "", - "domain": "", - "expertise": [], - "interests": [] - }, - "behavioral_hints": { - "learning_stage": "", - "preferred_depth": "", - "tone_preference": "" - }, - "knowledge_tags": [] - }, + "metadata_changes": [ + {"field_path": "profile.role", "action": "set", "value": "后端工程师"}, + {"field_path": "profile.expertise", "action": "set", "value": "Python"}, + {"field_path": "profile.expertise", "action": "remove", "value": "Java"} + ], "aliases_to_add": [], "aliases_to_remove": [] } diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 1de4b120..86ac5fe0 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import Any, Dict, Optional, TypeVar +from typing import Any, Dict, List, Optional, TypeVar from langchain_aws import ChatBedrock from langchain_community.chat_models import ChatTongyi @@ -9,12 +9,12 @@ from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLLM from langchain_ollama import OllamaLLM from langchain_openai import ChatOpenAI, OpenAI -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.models.models_model import ModelProvider, ModelType -from app.core.models.volcano_chat import VolcanoChatOpenAI +from app.core.models.compatible_chat import CompatibleChatOpenAI T = TypeVar("T") @@ -25,10 +25,11 @@ class RedBearModelConfig(BaseModel): provider: str api_key: str base_url: Optional[str] = None + capability: List[str] = Field(default_factory=list) # 模型能力列表,驱动所有能力开关 is_omni: bool = False # 是否为 Omni 模型 deep_thinking: bool = False # 是否启用深度思考模式 thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算 - support_thinking: bool = False # 模型是否支持 enable_thinking 参数(capability 含 thinking) + json_output: bool = False # 是否强制 JSON 输出 # 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置 timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 @@ -36,6 +37,23 @@ class RedBearModelConfig(BaseModel): concurrency: int = 5 # 并发限流 extra_params: Dict[str, Any] = {} + @model_validator(mode="after") + def _resolve_capabilities(self) -> "RedBearModelConfig": + from app.core.logging_config import get_business_logger + logger = get_business_logger() + if self.deep_thinking and "thinking" not in self.capability: + logger.warning( + f"模型 {self.model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking" + ) + self.deep_thinking = False + self.thinking_budget_tokens = None + if self.json_output and "json_output" not in self.capability: + logger.warning( + f"模型 {self.model_name} 不支持 JSON 输出(capability 中无 'json_output'),已自动关闭 json_output" + ) + self.json_output = False + return self + class RedBearModelFactory: """模型工厂类""" @@ -74,18 +92,19 @@ class RedBearModelFactory: is_streaming = bool(config.extra_params.get("streaming")) if is_streaming: params["stream_usage"] = True - # 只有支持 thinking 的模型才传 enable_thinking - if config.support_thinking: - model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {}) - if is_streaming: - model_kwargs["enable_thinking"] = config.deep_thinking - if config.deep_thinking: - model_kwargs["incremental_output"] = True - if config.thinking_budget_tokens: - model_kwargs["thinking_budget"] = config.thinking_budget_tokens - else: - model_kwargs["enable_thinking"] = False - params["model_kwargs"] = model_kwargs + # 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考 + if "thinking" in config.capability: + extra_body = params.setdefault("extra_body", {}) + if config.deep_thinking: + extra_body["enable_thinking"] = False + if is_streaming: + extra_body["enable_thinking"] = True + if config.thinking_budget_tokens: + extra_body["thinking_budget"] = config.thinking_budget_tokens + # JSON 输出模式 + if config.json_output: + model_kwargs = params.setdefault("model_kwargs", {}) + model_kwargs["response_format"] = {"type": "json_object"} return params if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: @@ -108,26 +127,31 @@ class RedBearModelFactory: **config.extra_params } # 流式模式下启用 stream_usage 以获取 token 统计 - if config.extra_params.get("streaming"): - params["stream_usage"] = True - # 深度思考模式 is_streaming = bool(config.extra_params.get("streaming")) - if is_streaming and not config.is_omni: + if is_streaming: + params["stream_usage"] = True + # 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考 + if "thinking" in config.capability: + # VOLCANO 深度思考仅流式支持 if provider == ModelProvider.VOLCANO: - # 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数 - thinking_config: Dict[str, Any] = { - "type": "enabled" if config.deep_thinking else "disabled" - } + thinking_config: Dict[str, Any] = {"type": "enabled" if config.deep_thinking else "disabled"} if config.deep_thinking and config.thinking_budget_tokens: thinking_config["budget_tokens"] = config.thinking_budget_tokens params["extra_body"] = {"thinking": thinking_config} else: - # 始终显式传递 enable_thinking,不支持该参数的模型(如 DeepSeek-R1)会直接忽略 - model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {}) - model_kwargs["enable_thinking"] = config.deep_thinking - if config.deep_thinking and config.thinking_budget_tokens: - model_kwargs["thinking_budget"] = config.thinking_budget_tokens - params["model_kwargs"] = model_kwargs + extra_body = params.setdefault("extra_body", {}) + if config.deep_thinking: + extra_body["enable_thinking"] = False + if is_streaming: + extra_body["enable_thinking"] = True + if config.thinking_budget_tokens: + extra_body["thinking_budget"] = config.thinking_budget_tokens + # JSON 输出模式 + if config.json_output: + model_kwargs = params.setdefault("model_kwargs", {}) + # VOLCANO 模型不支持 response_format,JSON 输出由 system prompt 注入实现 + if provider != ModelProvider.VOLCANO: + model_kwargs["response_format"] = {"type": "json_object"} return params elif provider == ModelProvider.DASHSCOPE: params = { @@ -136,19 +160,20 @@ class RedBearModelFactory: "max_retries": config.max_retries, **config.extra_params } - # 只有支持 thinking 的模型才传 enable_thinking - if config.support_thinking: + # 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考 + if "thinking" in config.capability: is_streaming = bool(config.extra_params.get("streaming")) - model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {}) - if is_streaming: - model_kwargs["enable_thinking"] = config.deep_thinking - if config.deep_thinking: - model_kwargs["incremental_output"] = True - if config.thinking_budget_tokens: - model_kwargs["thinking_budget"] = config.thinking_budget_tokens - else: + model_kwargs = params.setdefault("model_kwargs", {}) + if config.deep_thinking: model_kwargs["enable_thinking"] = False - params["model_kwargs"] = model_kwargs + if is_streaming: + model_kwargs["enable_thinking"] = True + model_kwargs["incremental_output"] = True + if config.thinking_budget_tokens: + model_kwargs["thinking_budget"] = config.thinking_budget_tokens + if config.json_output: + model_kwargs = params.setdefault("model_kwargs", {}) + model_kwargs["response_format"] = {"type": "json_object"} return params elif provider == ModelProvider.BEDROCK: # Bedrock 使用 AWS 凭证 @@ -195,6 +220,10 @@ class RedBearModelFactory: params["additional_model_request_fields"] = { "thinking": {"type": "enabled", "budget_tokens": budget} } + # JSON 输出模式 + if config.json_output: + model_kwargs = params.setdefault("model_kwargs", {}) + model_kwargs["response_format"] = {"type": "json_object"} return params else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) @@ -223,18 +252,19 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy """根据模型提供商获取对应的模型类""" provider = config.provider.lower() - # dashscope 的 omni 模型使用 OpenAI 兼容模式 + # dashscope的omni模型 和 volcano模型使用 if provider == ModelProvider.DASHSCOPE and config.is_omni: - return ChatOpenAI + return CompatibleChatOpenAI if provider == ModelProvider.VOLCANO: - return VolcanoChatOpenAI + return CompatibleChatOpenAI if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: - if type == ModelType.LLM: - return OpenAI - elif type == ModelType.CHAT: - return ChatOpenAI - else: - raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED) + return CompatibleChatOpenAI + # if type == ModelType.LLM: + # return OpenAI + # elif type == ModelType.CHAT: + # return CompatibleChatOpenAI + # else: + # raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED) elif provider == ModelProvider.DASHSCOPE: return ChatTongyi elif provider == ModelProvider.OLLAMA: diff --git a/api/app/core/models/volcano_chat.py b/api/app/core/models/compatible_chat.py similarity index 63% rename from api/app/core/models/volcano_chat.py rename to api/app/core/models/compatible_chat.py index d9a51d13..218c46e0 100644 --- a/api/app/core/models/volcano_chat.py +++ b/api/app/core/models/compatible_chat.py @@ -8,12 +8,33 @@ from __future__ import annotations from typing import Any, Optional, Union +from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_openai import ChatOpenAI -class VolcanoChatOpenAI(ChatOpenAI): - """火山引擎 Chat 模型,支持深度思考内容(reasoning_content)的流式和非流式透传。""" +class CompatibleChatOpenAI(ChatOpenAI): + """火山和千问的omni兼容模型,支持深度思考内容(reasoning_content)的流式和非流式透传。 + + 同时修复 json_output + tools 同时使用时 langchain_openai 强制走 .parse()/.stream() + 导致 strict 校验报错的问题:有工具时从 payload 中移除 response_format, + 让父类走普通 .create()/.astream() 路径,JSON 输出由 system prompt 指令保证。 + """ + + def _get_request_payload( + self, + input_: list[BaseMessage], + *, + stop: list[str] | None = None, + **kwargs: Any, + ) -> dict: + payload = super()._get_request_payload(input_, stop=stop, **kwargs) + # 有工具时 langchain_openai 检测到 response_format 会切换到 .parse()/.stream() + # 接口,OpenAI SDK 要求此时所有工具必须 strict=True,动态生成的工具不满足。 + # 移除 response_format,让父类走普通路径,JSON 输出由 system prompt 指令保证。 + if payload.get("tools") and "response_format" in payload: + payload.pop("response_format") + return payload def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult: result = super()._create_chat_result(response, generation_info) diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index 5b3a2f64..f96dba15 100644 --- a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -6,7 +6,8 @@ models: description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -20,6 +21,7 @@ models: is_official: true capability: - vision + - json_output is_omni: false tags: - 大语言模型 @@ -38,6 +40,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -54,7 +57,8 @@ models: description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -72,6 +76,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -87,7 +92,8 @@ models: description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -101,7 +107,8 @@ models: description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -115,7 +122,8 @@ models: description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -130,7 +138,8 @@ models: description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index d9e6a00f..9b45f107 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -8,6 +8,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -22,6 +23,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -36,6 +38,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -48,7 +51,8 @@ models: description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -61,7 +65,8 @@ models: description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -74,7 +79,8 @@ models: description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -87,7 +93,8 @@ models: description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -100,7 +107,8 @@ models: description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -115,7 +123,8 @@ models: description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -133,6 +142,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -150,6 +160,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -180,6 +191,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -210,7 +222,7 @@ models: is_deprecated: false is_official: true capability: - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -376,6 +388,7 @@ models: capability: - vision - video + - json_output is_omni: false tags: - 大语言模型 @@ -448,6 +461,7 @@ models: capability: - vision - video + - json_output is_omni: false tags: - 大语言模型 @@ -466,6 +480,7 @@ models: capability: - vision - video + - json_output is_omni: false tags: - 大语言模型 @@ -481,7 +496,8 @@ models: description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -498,6 +514,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -513,7 +530,7 @@ models: is_deprecated: false is_official: true capability: - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -530,6 +547,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -546,6 +564,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -561,7 +580,7 @@ models: is_deprecated: false is_official: true capability: - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -578,6 +597,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -594,6 +614,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -610,6 +631,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -626,6 +648,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -641,7 +664,7 @@ models: is_deprecated: false is_official: true capability: - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -656,7 +679,7 @@ models: is_deprecated: false is_official: true capability: - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -672,6 +695,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -687,6 +711,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -702,6 +727,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -719,6 +745,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -736,6 +763,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -752,6 +780,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -768,7 +797,7 @@ models: is_deprecated: false is_official: true capability: - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -785,6 +814,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -803,6 +833,8 @@ models: - vision - video - audio + - thinking + - json_output is_omni: true tags: - 大语言模型 @@ -822,7 +854,7 @@ models: capability: - vision - video - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -844,6 +876,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -864,7 +897,7 @@ models: capability: - vision - video - - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -886,6 +919,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -907,6 +941,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -928,6 +963,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -947,6 +983,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -964,6 +1001,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -979,6 +1017,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -994,6 +1033,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 08b81008..1c0a0b2d 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -10,6 +10,7 @@ models: - vision - audio - video + - json_output is_omni: true tags: - 大语言模型 @@ -27,7 +28,8 @@ models: description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -42,7 +44,8 @@ models: description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -57,7 +60,8 @@ models: description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -84,7 +88,8 @@ models: description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -99,7 +104,8 @@ models: description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -114,7 +120,8 @@ models: description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -131,6 +138,7 @@ models: is_official: true capability: - vision + - json_output is_omni: false tags: - 大语言模型 @@ -146,7 +154,8 @@ models: description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -163,6 +172,7 @@ models: is_official: true capability: - vision + - json_output is_omni: false tags: - 大语言模型 @@ -194,6 +204,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -213,6 +224,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -231,6 +243,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -248,6 +261,7 @@ models: is_official: true capability: - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -266,6 +280,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -284,6 +299,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -302,6 +318,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -321,6 +338,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -340,6 +358,7 @@ models: capability: - vision - thinking + - json_output is_omni: false tags: - 大语言模型 diff --git a/api/app/core/models/scripts/volcano_models.yaml b/api/app/core/models/scripts/volcano_models.yaml index c86d41ac..6658c2f9 100644 --- a/api/app/core/models/scripts/volcano_models.yaml +++ b/api/app/core/models/scripts/volcano_models.yaml @@ -11,6 +11,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -26,6 +27,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -41,6 +43,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -56,6 +59,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -72,6 +76,7 @@ models: capability: - vision - video + - json_output is_omni: false tags: - 大语言模型 @@ -87,6 +92,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -102,6 +108,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -117,6 +124,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -132,6 +140,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -148,6 +157,7 @@ models: - vision - video - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -175,7 +185,8 @@ models: description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -187,7 +198,8 @@ models: description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 diff --git a/api/app/core/quota_manager.py b/api/app/core/quota_manager.py new file mode 100644 index 00000000..0e0053a0 --- /dev/null +++ b/api/app/core/quota_manager.py @@ -0,0 +1,485 @@ +""" +统一配额管理器 - 社区版和 SaaS 版共用 + +配额来源策略: +1. 优先从 premium 模块的 tenant_subscriptions 表读取(SaaS 版) +2. 降级到 default_free_plan.py 配置文件(社区版兜底) +""" +import asyncio +import time +from functools import wraps +from typing import Optional, Callable, Dict, Any +from uuid import UUID + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.core.logging_config import get_auth_logger +from app.i18n.exceptions import QuotaExceededError + +logger = get_auth_logger() + + +def _get_user_from_kwargs(kwargs: dict): + """从 kwargs 中获取 user 对象""" + for key in ["user", "current_user"]: + if key in kwargs: + return kwargs[key] + return None + + +def _get_tenant_id_from_kwargs(db: Session, kwargs: dict): + """从 kwargs 中获取 tenant_id""" + user = _get_user_from_kwargs(kwargs) + if user and hasattr(user, 'tenant_id'): + return user.tenant_id + + workspace_id = kwargs.get("workspace_id") + if workspace_id: + from app.models.workspace_model import Workspace + workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first() + if workspace: + return workspace.tenant_id + + api_key_auth = kwargs.get("api_key_auth") + if api_key_auth and hasattr(api_key_auth, 'workspace_id'): + from app.models.workspace_model import Workspace + workspace = db.query(Workspace).filter(Workspace.id == api_key_auth.workspace_id).first() + if workspace: + return workspace.tenant_id + + data = kwargs.get("data") or kwargs.get("body") or kwargs.get("payload") + if data and hasattr(data, "workspace_id"): + from app.models.workspace_model import Workspace + workspace = db.query(Workspace).filter(Workspace.id == data.workspace_id).first() + if workspace: + return workspace.tenant_id + + share_data = kwargs.get("share_data") + if share_data and hasattr(share_data, 'share_token'): + from app.models.workspace_model import Workspace + from app.models.app_model import App + share_token = share_data.share_token + from app.models.release_share_model import ReleaseShare + share_record = db.query(ReleaseShare).filter(ReleaseShare.share_token == share_token).first() + if share_record: + app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first() + if app: + return app.workspace.tenant_id + + return None + + +def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]: + """ + 获取租户的配额配置 + + 优先级: + 1. premium 模块的 tenant_subscriptions(SaaS 版) + 2. default_free_plan.py 配置文件(社区版兜底) + """ + # 尝试从 premium 模块获取 + try: + from premium.platform_admin.package_plan_service import TenantSubscriptionService + quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id) + if quota_config: + logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置") + return quota_config + except (ModuleNotFoundError, ImportError, Exception) as e: + logger.debug(f"无法从 premium 模块获取配额配置: {e}") + + # 降级到配置文件 + try: + from app.config.default_free_plan import DEFAULT_FREE_PLAN + logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}") + return DEFAULT_FREE_PLAN.get("quotas") + except Exception as e: + logger.error(f"无法从配置文件获取配额: {e}") + return None + + +class QuotaUsageRepository: + """配额使用量数据访问层""" + + def __init__(self, db: Session): + self.db = db + + def count_workspaces(self, tenant_id: UUID) -> int: + from app.models.workspace_model import Workspace + return self.db.query(Workspace).filter( + Workspace.tenant_id == tenant_id, + Workspace.is_active.is_(True) + ).count() + + def count_apps(self, tenant_id: UUID) -> int: + from app.models.app_model import App + from app.models.workspace_model import Workspace + return self.db.query(App).join( + Workspace, App.workspace_id == Workspace.id + ).filter( + Workspace.tenant_id == tenant_id, + App.is_active.is_(True) + ).count() + + def count_skills(self, tenant_id: UUID) -> int: + from app.models.skill_model import Skill + return self.db.query(Skill).filter( + Skill.tenant_id == tenant_id, + Skill.is_active.is_(True) + ).count() + + def sum_knowledge_capacity_gb(self, tenant_id: UUID) -> float: + from app.models.document_model import Document + from app.models.knowledge_model import Knowledge + from app.models.workspace_model import Workspace + result = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join( + Knowledge, Document.kb_id == Knowledge.id + ).join( + Workspace, Knowledge.workspace_id == Workspace.id + ).filter( + Workspace.tenant_id == tenant_id, + Document.status == 1, + ).scalar() + return float(result) / (1024 ** 3) if result else 0.0 + + def count_memory_engines(self, tenant_id: UUID) -> int: + from app.models.memory_config_model import MemoryConfig + from app.models.workspace_model import Workspace + return self.db.query(MemoryConfig).join( + Workspace, MemoryConfig.workspace_id == Workspace.id + ).filter( + Workspace.tenant_id == tenant_id + ).count() + + def count_end_users(self, tenant_id: UUID) -> int: + from app.models.end_user_model import EndUser + from app.models.workspace_model import Workspace + return self.db.query(EndUser).join( + Workspace, EndUser.workspace_id == Workspace.id + ).filter( + Workspace.tenant_id == tenant_id + ).count() + + def count_models(self, tenant_id: UUID) -> int: + from app.models.models_model import ModelConfig + return self.db.query(ModelConfig).filter( + ModelConfig.tenant_id == tenant_id, + ModelConfig.is_active == True + ).count() + + def count_ontology_projects(self, tenant_id: UUID) -> int: + from app.models.ontology_scene import OntologyScene + from app.models.workspace_model import Workspace + return self.db.query(OntologyScene).join( + Workspace, OntologyScene.workspace_id == Workspace.id + ).filter( + Workspace.tenant_id == tenant_id + ).count() + + def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str): + """按配额类型分发,返回当前使用量""" + dispatch = { + "workspace_quota": self.count_workspaces, + "app_quota": self.count_apps, + "skill_quota": self.count_skills, + "knowledge_capacity_quota": self.sum_knowledge_capacity_gb, + "memory_engine_quota": self.count_memory_engines, + "end_user_quota": self.count_end_users, + "model_quota": self.count_models, + "ontology_project_quota": self.count_ontology_projects, + } + fn = dispatch.get(quota_type) + return fn(tenant_id) if fn else 0 + + +def _check_quota( + db: Session, + tenant_id: UUID, + quota_type: str, + resource_name: str, + usage_func: Optional[Callable] = None, +) -> None: + """核心配额检查逻辑:对比使用量和配额限制""" + try: + quota_config = _get_quota_config(db, tenant_id) + if not quota_config: + logger.warning(f"租户 {tenant_id} 无有效配额配置,跳过配额检查") + return + + quota_limit = quota_config.get(quota_type) + if quota_limit is None: + logger.warning(f"配额配置未包含 {quota_type},跳过配额检查") + return + + if usage_func: + current_usage = usage_func(db, tenant_id) + else: + current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type) + + if current_usage >= quota_limit: + logger.warning( + f"配额不足: tenant={tenant_id}, type={quota_type}, " + f"usage={current_usage}, limit={quota_limit}" + ) + raise QuotaExceededError( + resource=resource_name, + current_usage=current_usage, + quota_limit=quota_limit, + ) + + logger.debug( + f"配额检查通过: tenant={tenant_id}, type={quota_type}, " + f"usage={current_usage}, limit={quota_limit}" + ) + + except QuotaExceededError: + raise + except Exception as e: + logger.error( + f"配额检查异常: tenant={tenant_id}, type={quota_type}, " + f"error_type={type(e).__name__}, error={str(e)}", + exc_info=True, + ) + raise + + +# ─── 具名装饰器 ──────────────────────────────────────────────────────────── + +def check_workspace_quota(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, "workspace_quota", "workspace") + return func(*args, **kwargs) + return wrapper + + +def check_skill_quota(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, "skill_quota", "skill") + return func(*args, **kwargs) + return wrapper + + +def check_app_quota(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, "app_quota", "app") + return func(*args, **kwargs) + return wrapper + + +def check_knowledge_capacity_quota(func: Callable) -> Callable: + @wraps(func) + async def async_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + if not db: + logger.warning("配额检查失败:缺少 db 参数") + return await func(*args, **kwargs) + tenant_id = _get_tenant_id_from_kwargs(db, kwargs) + if not tenant_id: + logger.warning("配额检查失败:无法获取 tenant_id") + return await func(*args, **kwargs) + _check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity") + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity") + return func(*args, **kwargs) + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + + +def check_memory_engine_quota(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine") + return func(*args, **kwargs) + return wrapper + + +def check_end_user_quota(func: Callable) -> Callable: + @wraps(func) + async def async_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + if not db: + logger.warning("配额检查失败:缺少 db 参数") + return await func(*args, **kwargs) + tenant_id = _get_tenant_id_from_kwargs(db, kwargs) + if not tenant_id: + logger.warning("配额检查失败:无法获取 tenant_id") + return await func(*args, **kwargs) + _check_quota(db, tenant_id, "end_user_quota", "end_user") + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + if not db: + logger.warning("配额检查失败:缺少 db 参数") + return func(*args, **kwargs) + tenant_id = _get_tenant_id_from_kwargs(db, kwargs) + if not tenant_id: + logger.warning("配额检查失败:无法获取 tenant_id") + return func(*args, **kwargs) + _check_quota(db, tenant_id, "end_user_quota", "end_user") + return func(*args, **kwargs) + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + + +def check_ontology_project_quota(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project") + return func(*args, **kwargs) + return wrapper + + +def check_model_quota(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, "model_quota", "model") + return func(*args, **kwargs) + return wrapper + + +def check_model_activation_quota(func: Callable) -> Callable: + """模型激活时的配额检查装饰器""" + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + + model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None) + model_data = kwargs.get("model_data") + + if not model_id or not model_data: + logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数") + return func(*args, **kwargs) + + if model_data.is_active is True: + try: + from app.models.models_model import ModelConfig + from app.services.model_service import ModelConfigService + + existing_model = ModelConfigService.get_model_by_id( + db=db, + model_id=model_id, + tenant_id=user.tenant_id + ) + + if not existing_model.is_active: + logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}") + _check_quota(db, user.tenant_id, "model_quota", "model") + except Exception as e: + logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}") + raise + + return func(*args, **kwargs) + return wrapper + + +def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None): + """通用配额检查装饰器,支持自定义使用量获取函数""" + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.warning("配额检查失败:缺少 db 或 user 参数") + return func(*args, **kwargs) + _check_quota(db, user.tenant_id, quota_type, resource_name, usage_func) + return func(*args, **kwargs) + return wrapper + return decorator + + +# ─── 配额使用统计 ──────────────────────────────────────────────────────────── + +def get_quota_usage(db: Session, tenant_id: UUID) -> dict: + """获取租户所有配额的使用情况""" + quota_config = _get_quota_config(db, tenant_id) + if not quota_config: + return {} + + repo = QuotaUsageRepository(db) + + def pct(used, limit): + return round(used / limit * 100, 1) if limit else None + + workspace_count = repo.count_workspaces(tenant_id) + skill_count = repo.count_skills(tenant_id) + app_count = repo.count_apps(tenant_id) + knowledge_gb = repo.sum_knowledge_capacity_gb(tenant_id) + memory_count = repo.count_memory_engines(tenant_id) + end_user_count = repo.count_end_users(tenant_id) + model_count = repo.count_models(tenant_id) + ontology_count = repo.count_ontology_projects(tenant_id) + + api_ops_current = 0 + try: + from app.core.config import settings + import redis + _now = time.time() + _rk = f"rate_limit:tenant_qps:{tenant_id}" + _r = redis.StrictRedis( + host=settings.REDIS_HOST, port=settings.REDIS_PORT, + db=settings.REDIS_DB, password=settings.REDIS_PASSWORD, + decode_responses=True + ) + api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf")) + except Exception: + pass + + return { + "workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))}, + "skill": {"used": skill_count, "limit": quota_config.get("skill_quota"), "percentage": pct(skill_count, quota_config.get("skill_quota"))}, + "app": {"used": app_count, "limit": quota_config.get("app_quota"), "percentage": pct(app_count, quota_config.get("app_quota"))}, + "knowledge_capacity": {"used": round(knowledge_gb, 2), "limit": quota_config.get("knowledge_capacity_quota"), "percentage": pct(knowledge_gb, quota_config.get("knowledge_capacity_quota")), "unit": "GB"}, + "memory_engine": {"used": memory_count, "limit": quota_config.get("memory_engine_quota"), "percentage": pct(memory_count, quota_config.get("memory_engine_quota"))}, + "end_user": {"used": end_user_count, "limit": quota_config.get("end_user_quota"), "percentage": pct(end_user_count, quota_config.get("end_user_quota"))}, + "ontology_project": {"used": ontology_count, "limit": quota_config.get("ontology_project_quota"), "percentage": pct(ontology_count, quota_config.get("ontology_project_quota"))}, + "model": {"used": model_count, "limit": quota_config.get("model_quota"), "percentage": pct(model_count, quota_config.get("model_quota"))}, + "api_ops_rate_limit": {"current": api_ops_current, "limit": quota_config.get("api_ops_rate_limit"), "percentage": None, "unit": "次/秒"}, + } diff --git a/api/app/core/quota_stub.py b/api/app/core/quota_stub.py new file mode 100644 index 00000000..577dfadb --- /dev/null +++ b/api/app/core/quota_stub.py @@ -0,0 +1,36 @@ +""" +配额检查 stub - 社区版和 SaaS 版统一使用 core.quota_manager 实现 + +所有配额检查逻辑统一在 core 层实现,两个版本共用: +- 社区版:从 default_free_plan.py 读取配额限制 +- SaaS 版:优先从 tenant_subscriptions 表读取,降级到配置文件 +""" +from app.core.quota_manager import ( + check_workspace_quota, + check_skill_quota, + check_app_quota, + check_knowledge_capacity_quota, + check_memory_engine_quota, + check_end_user_quota, + check_ontology_project_quota, + check_model_quota, + check_model_activation_quota, + get_quota_usage, + _check_quota, + QuotaUsageRepository, +) + +__all__ = [ + "check_workspace_quota", + "check_skill_quota", + "check_app_quota", + "check_knowledge_capacity_quota", + "check_memory_engine_quota", + "check_end_user_quota", + "check_ontology_project_quota", + "check_model_quota", + "check_model_activation_quota", + "get_quota_usage", + "_check_quota", + "QuotaUsageRepository", +] diff --git a/api/app/core/rag/app/naive.py b/api/app/core/rag/app/naive.py index 72272347..312216dd 100644 --- a/api/app/core/rag/app/naive.py +++ b/api/app/core/rag/app/naive.py @@ -672,10 +672,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, excel_parser = ExcelParser() if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true": sections = [(_, "") for _ in excel_parser.html(binary, 12) if _] - parser_config["chunk_token_num"] = 0 else: sections = [(_, "") for _ in excel_parser(binary) if _] - parser_config["chunk_token_num"] = 12800 + callback(0.8, "Finish parsing.") + # Excel 每行直接作为一个 chunk,不经过 naive_merge 避免被 delimiter 拆分 + chunks = [s for s, _ in sections] + res.extend(tokenize_chunks(chunks, doc, is_english, None)) + res.extend(embed_res) + res.extend(url_res) + return res elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") diff --git a/api/app/core/rag/deepdoc/parser/excel_parser.py b/api/app/core/rag/deepdoc/parser/excel_parser.py index d66a21a8..c3999be9 100644 --- a/api/app/core/rag/deepdoc/parser/excel_parser.py +++ b/api/app/core/rag/deepdoc/parser/excel_parser.py @@ -232,14 +232,14 @@ class RAGExcelParser: t = str(ti[i].value) if i < len(ti) else "" t += (":" if t else "") + str(c.value) fields.append(t) - line = "; ".join(fields) + line = "\n".join(fields) if sheetname.lower().find("sheet") < 0: - line += " ——" + sheetname + line += "\n——" + sheetname res.append(line) else: # 只有表头的情况 if header_fields: - line = "; ".join(header_fields) + line = "\n".join(header_fields) if sheetname.lower().find("sheet") < 0: line += " ——" + sheetname res.append(line) diff --git a/api/app/core/rag/llm/embedding_model.py b/api/app/core/rag/llm/embedding_model.py index 22e35a15..59210054 100644 --- a/api/app/core/rag/llm/embedding_model.py +++ b/api/app/core/rag/llm/embedding_model.py @@ -50,7 +50,9 @@ class OpenAIEmbed(Base): def encode(self, texts: list): # OpenAI requires batch size <=16 batch_size = 16 - texts = [truncate(t, 8191) for t in texts] + # Use 8000 instead of 8191 to leave safety margin for tokenizer differences + # between cl100k_base (used by truncate) and the actual embedding model + texts = [truncate(t, 8000) for t in texts] ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): @@ -63,7 +65,7 @@ class OpenAIEmbed(Base): return np.array(ress), total_tokens def encode_queries(self, text): - res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True}) + res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True}) return np.array(res.data[0].embedding), self.total_token_count(res) @@ -79,6 +81,7 @@ class LocalAIEmbed(Base): def encode(self, texts: list): batch_size = 16 + texts = [truncate(t, 8000) for t in texts] ress = [] for i in range(0, len(texts), batch_size): res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) @@ -173,6 +176,7 @@ class XinferenceEmbed(Base): def encode(self, texts: list): batch_size = 16 + texts = [truncate(t, 8000) for t in texts] ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): @@ -188,7 +192,7 @@ class XinferenceEmbed(Base): def encode_queries(self, text): res = None try: - res = self.client.embeddings.create(input=[text], model=self.model_name) + res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name) return np.array(res.data[0].embedding), self.total_token_count(res) except Exception as _e: log_exception(_e, res) diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py index 2fda6b8b..d37e2dcd 100644 --- a/api/app/core/tools/builtin/datetime_tool.py +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool): return { "datetime": input_value, "timezone": timezone_str, - "timestamp": int(dt.timestamp()) * 1000, + "timestamp": int(dt.timestamp() * 1000), "iso_format": dt.isoformat(), - "result_data": int(dt.timestamp()) * 1000 + "result_data": int(dt.timestamp() * 1000) } def _calculate_datetime(self, kwargs) -> dict: diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 08d10e22..b34efe15 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -201,12 +201,15 @@ class VariablePool: @staticmethod def _extract_field(struct: "VariableStruct", field: str | None) -> Any: - """If field is given, drill into a dict/object variable's value.""" + """If field is given, drill into a dict/object/array[file] variable's value.""" if field is None: return struct.instance.get_value() value = struct.instance.get_value() + # array[file]: extract the field from every element, return a list + if isinstance(value, list): + return [item.get(field) if isinstance(item, dict) else getattr(item, field, None) for item in value] if not isinstance(value, dict): - raise KeyError(f"Variable is not an object, cannot access field '{field}'") + raise KeyError(f"Variable is not an object or array, cannot access field '{field}'") return value.get(field) def get_instance( diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index cf7ac976..1633b9c7 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -28,86 +28,135 @@ class IterationRuntime: def __init__( self, - start_id: str, stream: bool, - graph: CompiledStateGraph, node_id: str, config: dict[str, Any], state: WorkflowState, variable_pool: VariablePool, - child_variable_pool: VariablePool, + cycle_nodes: list, + cycle_edges: list, ): """ Initialize the iteration runtime. Args: - graph: Compiled workflow graph capable of async invocation. - node_id: Unique identifier of the loop node. - config: Dictionary containing iteration node configuration. - state: Current workflow state at the point of iteration. + stream: Whether to run in streaming mode. When True, each iteration + uses graph.astream and emits cycle_item events in real time. + When False, graph.ainvoke is used instead. + node_id: The unique identifier of the iteration node in the workflow. + Also used as the variable namespace for item/index inside + the subgraph (e.g. {{ node_id.item }}). + config: Raw configuration dict for the iteration node, parsed into + IterationNodeConfig. Controls input/output variable selectors, + parallel execution settings, and output flattening. + state: The parent workflow state at the point the iteration node is + entered. Each task receives a copy of this state as its + starting point. + variable_pool: The parent VariablePool containing all variables available + at the time the iteration node executes, including sys.*, + conv.*, and outputs from upstream nodes. Used as the source + for deep-copying into each task's independent child pool. + cycle_nodes: List of node config dicts belonging to this iteration's + subgraph (i.e. nodes whose cycle field equals node_id). + Passed to GraphBuilder when constructing each task's subgraph. + cycle_edges: List of edge config dicts connecting nodes within the subgraph. + Passed to GraphBuilder alongside cycle_nodes. """ - self.start_id = start_id self.stream = stream - self.graph = graph self.state = state self.node_id = node_id self.typed_config = IterationNodeConfig(**config) self.looping = True self.variable_pool = variable_pool - self.child_variable_pool = child_variable_pool + self.cycle_nodes = cycle_nodes + self.cycle_edges = cycle_edges self.event_write = get_stream_writer() - self.checkpoint = RunnableConfig( - configurable={ - "thread_id": uuid.uuid4() - } - ) self.output_value = None self.result: list = [] - async def _init_iteration_state(self, item, idx): + def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]: """ - Initialize a per-iteration copy of the workflow state. + Build an independent compiled subgraph for a single iteration task. - Args: - item: Current element from the input array for this iteration. - idx: Index of the element in the input array. + Each call creates a brand-new VariablePool by deep-copying the parent pool, + then passes it to GraphBuilder. GraphBuilder binds this pool to every node's + execution closure at build time, so the pool and the subgraph always reference + the same object. This is the key design invariant: item/index written into the + pool after build will be visible to all nodes inside the subgraph. Returns: - A copy of the workflow state with iteration-specific variables set. + graph: The compiled LangGraph subgraph ready for invocation. + child_pool: The VariablePool bound to this subgraph's node closures. + Callers must write item/index into this pool before invoking + the graph, and read output from it after invocation. + start_node_id: The ID of the CYCLE_START node inside the subgraph, + used to set the initial activation signal in workflow state. """ - loopstate = WorkflowState( - **self.state + from app.core.workflow.engine.graph_builder import GraphBuilder + child_pool = VariablePool() + child_pool.copy(self.variable_pool) + builder = GraphBuilder( + {"nodes": self.cycle_nodes, "edges": self.cycle_edges}, + stream=self.stream, + variable_pool=child_pool, + cycle=self.node_id, ) - self.child_variable_pool.copy(self.variable_pool) - await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True) - await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True) - loopstate["node_outputs"][self.node_id] = { - "item": item, - "index": idx, - } + graph = builder.build() + return graph, builder.variable_pool, builder.start_node_id + + async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str): + """ + Initialize the workflow state for a single iteration. + + Writes the current item and its index into child_pool under the iteration + node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them + accessible to downstream nodes inside the subgraph via variable selectors. + + Also prepares a copy of the parent workflow state with: + - node_outputs[node_id] set to {item, index} so the state snapshot is consistent + with the pool values. + - looping flag set to 1 (active) to signal the subgraph is inside a cycle. + - activate[start_id] set to True to trigger the CYCLE_START node. + + Args: + item: The current element from the input array. + idx: The zero-based index of this element in the input array. + child_pool: The VariablePool bound to this iteration's subgraph. + Must be the same object returned by _build_child_graph. + start_id: The ID of the CYCLE_START node inside the subgraph. + + Returns: + A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream. + """ + loopstate = WorkflowState(**self.state) + await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True) + await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True) + loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx} loopstate["looping"] = 1 - loopstate["activate"][self.start_id] = True + loopstate["activate"][start_id] = True return loopstate - def merge_conv_vars(self): - self.variable_pool.variables["conv"].update( - self.child_variable_pool.variables["conv"] - ) + def _merge_conv_vars(self, child_pool: VariablePool): + self.variable_pool.variables["conv"].update(child_pool.variables["conv"]) async def run_task(self, item, idx): """ Execute a single iteration asynchronously. + Each task builds its own subgraph so the variable pool closure is independent. - Args: - item: The input element for this iteration. - idx: The index of this iteration. + Returns: + Tuple of (idx, output, result, child_pool, stopped) """ + graph, child_pool, start_id = self._build_child_graph() + checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()}) + init_state = await self._init_iteration_state(item, idx, child_pool, start_id) + if self.stream: - async for event in self.graph.astream( - await self._init_iteration_state(item, idx), + async for event in graph.astream( + init_state, stream_mode=["debug"], - config=self.checkpoint + config=checkpoint ): if isinstance(event, tuple) and len(event) == 2: mode, data = event @@ -117,7 +166,6 @@ class IterationRuntime: event_type = data.get("type") payload = data.get("payload", {}) node_name = payload.get("name") - if node_name and node_name.startswith("nop"): continue if event_type == "task_result": @@ -140,17 +188,13 @@ class IterationRuntime: "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage") } }) - result = self.graph.get_state(config=self.checkpoint).values + result = graph.get_state(config=checkpoint).values else: - result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) - output = self.child_variable_pool.get_value(self.output_value) - if isinstance(output, list) and self.typed_config.flatten: - self.result.extend(output) - else: - self.result.append(output) - if result["looping"] == 2: - self.looping = False - return result + result = await graph.ainvoke(init_state) + + output = child_pool.get_value(self.output_value) + stopped = result["looping"] == 2 + return idx, output, result, child_pool, stopped def _create_iteration_tasks(self, array_obj, idx): """ @@ -196,16 +240,32 @@ class IterationRuntime: tasks = self._create_iteration_tasks(array_obj, idx) logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}") idx += self.typed_config.parallel_count - child_state.extend(await asyncio.gather(*tasks)) - self.merge_conv_vars() + batch = await asyncio.gather(*tasks) + # Sort by idx to preserve order, then collect results + batch_sorted = sorted(batch, key=lambda x: x[0]) + for _, output, result, child_pool, stopped in batch_sorted: + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + child_state.append(result) + self._merge_conv_vars(child_pool) + if stopped: + self.looping = False else: # Execute iterations sequentially while idx < len(array_obj) and self.looping: logger.info(f"Iteration node {self.node_id}: running") item = array_obj[idx] - result = await self.run_task(item, idx) - self.merge_conv_vars() + _, output, result, child_pool, stopped = await self.run_task(item, idx) + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + self._merge_conv_vars(child_pool) child_state.append(result) + if stopped: + self.looping = False idx += 1 logger.info(f"Iteration node {self.node_id}: execution completed") return { diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 68c83025..002c34df 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode): return cycle_nodes, cycle_edges - def build_graph(self): + def build_graph(self, variable_pool: VariablePool): """ Build and compile the internal subgraph for this cycle node. @@ -135,6 +135,7 @@ class CycleGraphNode(BaseNode): from app.core.workflow.engine.graph_builder import GraphBuilder self.child_variable_pool = VariablePool() + self.child_variable_pool.copy(variable_pool) builder = GraphBuilder( { "nodes": self.cycle_nodes, @@ -165,8 +166,8 @@ class CycleGraphNode(BaseNode): Raises: RuntimeError: If the node type is unsupported. """ - self.build_graph() if self.node_type == NodeType.LOOP: + self.build_graph(variable_pool) return await LoopRuntime( start_id=self.start_node_id, stream=False, @@ -179,20 +180,19 @@ class CycleGraphNode(BaseNode): ).run() if self.node_type == NodeType.ITERATION: return await IterationRuntime( - start_id=self.start_node_id, stream=False, - graph=self.graph, node_id=self.node_id, config=self.config, state=state, variable_pool=variable_pool, - child_variable_pool=self.child_variable_pool + cycle_nodes=self.cycle_nodes, + cycle_edges=self.cycle_edges, ).run() raise RuntimeError("Unknown cycle node type") async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): - self.build_graph() if self.node_type == NodeType.LOOP: + self.build_graph(variable_pool) yield { "__final__": True, "result": await LoopRuntime( @@ -211,14 +211,13 @@ class CycleGraphNode(BaseNode): yield { "__final__": True, "result": await IterationRuntime( - start_id=self.start_node_id, stream=True, - graph=self.graph, node_id=self.node_id, config=self.config, state=state, variable_pool=variable_pool, - child_variable_pool=self.child_variable_pool + cycle_nodes=self.cycle_nodes, + cycle_edges=self.cycle_edges, ).run() } return diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index e1b84f0c..72474436 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -72,8 +72,9 @@ class HttpContentTypeConfig(BaseModel): @classmethod def validate_data(cls, v, info): content_type = info.data.get("content_type") - if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData): - raise ValueError("When content_type is 'form-data', data must be of type HttpFormData") + if content_type == HttpContentType.FROM_DATA and ( + not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v)): + raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData") elif content_type in [HttpContentType.JSON] and not isinstance(v, str): raise ValueError("When content_type is JSON, data must be of type str") elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict): diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 086bee4a..783c230b 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -260,17 +260,22 @@ class HttpRequestNode(BaseNode): )) case HttpContentType.FROM_DATA: data = {} - content["files"] = {} + files = [] for item in self.typed_config.body.data: + key = self._render_template(item.key, variable_pool) if item.type == "text": - data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, - variable_pool) + data[key] = self._render_template(item.value, variable_pool) elif item.type == "file": - content["files"][self._render_template(item.key, variable_pool)] = ( - uuid.uuid4().hex, - await variable_pool.get_instance(item.value).get_content() - ) + file_instance = variable_pool.get_instance(item.value) + if isinstance(file_instance, ArrayVariable): + for v in file_instance.value: + if isinstance(v, FileVariable): + files.append((key, (uuid.uuid4().hex, await v.get_content()))) + elif isinstance(file_instance, FileVariable): + files.append((key, (uuid.uuid4().hex, await file_instance.get_content()))) content["data"] = data + if files: + content["files"] = files case HttpContentType.BINARY: content["files"] = [] file_instence = variable_pool.get_instance(self.typed_config.body.data) diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 638e4b2d..4a5b3860 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -6,6 +6,30 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType +class SubVariableConditionItem(BaseModel): + """A single condition on a file object's field, used inside sub_variable_condition.""" + key: str = Field(..., description="Field name of the file object, e.g. type, size, name") + operator: ComparisonOperator = Field(..., description="Comparison operator") + value: Any = Field(default=None, description="Value to compare with, or variable selector when input_type=variable") + input_type: ValueInputType = Field(default=ValueInputType.CONSTANT, description="constant or variable") + + @field_validator("input_type", mode="before") + @classmethod + def lower_input_type(cls, v): + if isinstance(v, str): + try: + return ValueInputType(v.lower()) + except ValueError: + raise ValueError(f"Invalid input_type: {v}") + return v + + +class SubVariableCondition(BaseModel): + """Sub-conditions applied to each file element in an array[file] variable.""" + logical_operator: LogicOperator = Field(default=LogicOperator.AND) + conditions: list[SubVariableConditionItem] = Field(default_factory=list) + + class ConditionDetail(BaseModel): operator: ComparisonOperator = Field( ..., @@ -14,12 +38,12 @@ class ConditionDetail(BaseModel): left: str = Field( ..., - description="Value to compare against" + description="Variable selector, e.g. {{sys.files}}" ) right: Any = Field( default=None, - description="Value to compare with" + description="Value to compare with (unused when sub_variable_condition is set)" ) input_type: ValueInputType = Field( @@ -27,6 +51,11 @@ class ConditionDetail(BaseModel): description="Value input type for comparison" ) + sub_variable_condition: SubVariableCondition | None = Field( + default=None, + description="Sub-conditions for array[file] fields. When set, operator must be contains/not_contains." + ) + @field_validator("input_type", mode="before") @classmethod def lower_input_type(cls, v): @@ -39,16 +68,19 @@ class ConditionDetail(BaseModel): class ConditionBranchConfig(BaseModel): - """Configuration for a conditional branch""" + """Configuration for a conditional branch. + + logical_operator controls how all expressions are combined (AND/OR). + """ logical_operator: LogicOperator = Field( default=LogicOperator.AND, - description="Logical operator used to combine multiple condition expressions" + description="Logical operator used to combine all conditions" ) expressions: list[ConditionDetail] = Field( - ..., - description="List of condition expressions within this branch" + default_factory=list, + description="List of conditions within this branch" ) diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index ec46b20b..c4d3a0e6 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -7,7 +7,7 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType from app.core.workflow.nodes.if_else import IfElseNodeConfig -from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance +from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance, ArrayFileContainsOperator from app.core.workflow.variable.base_variable import VariableType logger = logging.getLogger(__name__) @@ -90,11 +90,9 @@ class IfElseNode(BaseNode): list[str]: A list of Python boolean expression strings, ordered by branch priority. """ - branch_index = 0 conditions = [] for case_branch in self.typed_config.cases: - branch_index += 1 branch_result = [] for expression in case_branch.expressions: pattern = r"\{\{\s*(.*?)\s*\}\}" @@ -103,13 +101,18 @@ class IfElseNode(BaseNode): left_value = self.get_variable(left_string, variable_pool) except KeyError: left_value = None - evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( - variable_pool, - expression.left, - expression.right, - expression.input_type - ) + + if expression.sub_variable_condition is not None and isinstance(left_value, list): + evaluator = ArrayFileContainsOperator(left_value, expression.sub_variable_condition, variable_pool) + else: + evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( + variable_pool, + expression.left, + expression.right, + expression.input_type + ) branch_result.append(self._evaluate(expression.operator, evaluator)) + if case_branch.logical_operator == LogicOperator.AND: conditions.append(all(branch_result)) else: diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 771262c1..b815c80f 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -116,6 +116,11 @@ class LLMNodeConfig(BaseNodeConfig): description="Top-p 采样参数" ) + json_output: bool = Field( + default=False, + description="是否以 JSON 格式输出" + ) + frequency_penalty: float | None = Field( default=None, ge=-2.0, diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index bb87c845..db7f1009 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -22,6 +22,7 @@ from app.db import get_db_context from app.models import ModelType from app.schemas.model_schema import ModelInfo from app.services.model_service import ModelConfigService +from app.models.models_model import ModelProvider logger = logging.getLogger(__name__) @@ -126,7 +127,11 @@ class LLMNode(BaseNode): # 4. 创建 LLM 实例(使用已提取的数据) # 注意:对于流式输出,需要在模型初始化时设置 streaming=True - extra_params = {"streaming": stream} if stream else {} + extra_params: dict[str, Any] = {"streaming": stream} if stream else {} + if self.typed_config.temperature is not None: + extra_params["temperature"] = self.typed_config.temperature + if self.typed_config.max_tokens is not None: + extra_params["max_tokens"] = self.typed_config.max_tokens llm = RedBearLLM( RedBearModelConfig( @@ -135,7 +140,9 @@ class LLMNode(BaseNode): api_key=model_info.api_key, base_url=model_info.api_base, extra_params=extra_params, - is_omni=model_info.is_omni + is_omni=model_info.is_omni, + capability=model_info.capability, + json_output=self.typed_config.json_output, ), type=model_info.model_type ) @@ -218,6 +225,19 @@ class LLMNode(BaseNode): rendered = self._render_template(prompt_template, variable_pool) self.messages = [{"role": "user", "content": rendered}] + # ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format,在 system prompt 中注入 + # VOLCANO 模型不支持 response_format,同样需要 system prompt 注入 + need_json_prompt = self.typed_config.json_output and ( + (model_info.provider.lower() == ModelProvider.DASHSCOPE and not model_info.is_omni) + or model_info.provider.lower() == ModelProvider.VOLCANO + ) + if need_json_prompt: + system_msg = next((m for m in self.messages if m["role"] == "system"), None) + if system_msg: + system_msg["content"] += "\n请以JSON格式输出。" + else: + self.messages.insert(0, {"role": "system", "content": "请以JSON格式输出。"}) + return llm async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index 14fc9d9f..62eebbfe 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -395,11 +395,73 @@ class NoneObjectComparisonOperator: return lambda *args, **kwargs: False +class ArrayFileContainsOperator: + """Handles contains/not_contains on array[file] with sub_variable_condition.""" + + def __init__(self, left_value: list[dict], sub_variable_condition: Any, pool: VariablePool | None = None): + self.left_value = left_value + self.sub_variable_condition = sub_variable_condition + self.pool = pool + + def _resolve_value(self, cond: Any) -> Any: + if cond.input_type == ValueInputType.VARIABLE and self.pool is not None: + pattern = r"\{\{\s*(.*?)\s*\}\}" + selector = re.sub(pattern, r"\1", str(cond.value)).strip() + return self.pool.get_value(selector, default=None, strict=False) + return cond.value + + def _match_item(self, file_item: dict) -> bool: + results = [] + for cond in self.sub_variable_condition.conditions: + field_val = file_item.get(cond.key) + expected = self._resolve_value(cond) + result = self._eval_sub(field_val, cond.operator.value, expected) + results.append(result) + if self.sub_variable_condition.logical_operator.value == "and": + return all(results) + return any(results) + + @staticmethod + def _eval_sub(field_val: Any, op: str, expected: Any) -> bool: + if field_val is None: + return op == "empty" + match op: + case "eq": return str(field_val) == str(expected) + case "ne": return str(field_val) != str(expected) + case "contains": return isinstance(field_val, str) and str(expected) in field_val + case "not_contains": return isinstance(field_val, str) and str(expected) not in field_val + case "in": return field_val in (expected if isinstance(expected, list) else [expected]) + case "not_in": return field_val not in (expected if isinstance(expected, list) else [expected]) + case "gt": return isinstance(field_val, (int, float)) and field_val > float(expected) + case "ge": return isinstance(field_val, (int, float)) and field_val >= float(expected) + case "lt": return isinstance(field_val, (int, float)) and field_val < float(expected) + case "le": return isinstance(field_val, (int, float)) and field_val <= float(expected) + case "empty": return field_val in (None, "", 0) + case "not_empty": return field_val not in (None, "", 0) + case _: return False + + def contains(self) -> bool: + return any(self._match_item(f) for f in self.left_value if isinstance(f, dict)) + + def not_contains(self) -> bool: + return not self.contains() + + def empty(self) -> bool: + return not self.left_value + + def not_empty(self) -> bool: + return bool(self.left_value) + + def __getattr__(self, name): + return lambda *args, **kwargs: False + + CompareOperatorInstance = Union[ StringComparisonOperator, NumberComparisonOperator, BooleanComparisonOperator, ArrayComparisonOperator, + ArrayFileContainsOperator, ObjectComparisonOperator ] CompareOperatorType = Type[CompareOperatorInstance] diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 72c5c6a8..410f64c3 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -15,6 +15,7 @@ from app.services.tool_service import ToolService logger = logging.getLogger(__name__) TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}") +PURE_VARIABLE_PATTERN = re.compile(r"^\{\{\s*([\w.]+)\s*}}$") class ToolNode(BaseNode): @@ -52,13 +53,21 @@ class ToolNode(BaseNode): # 渲染工具参数 rendered_parameters = {} for param_name, param_template in self.typed_config.tool_parameters.items(): - if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): - try: - rendered_value = self._render_template(param_template, variable_pool) - except Exception as e: - raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + if isinstance(param_template, str): + pure_match = PURE_VARIABLE_PATTERN.match(param_template) + if pure_match: + # 纯单变量引用直接取原始值,保留 int/bool/float 等类型 + rendered_value = self.get_variable(pure_match.group(1), variable_pool, strict=False) + if rendered_value is None: + rendered_value = self._render_template(param_template, variable_pool) + elif TEMPLATE_PATTERN.search(param_template): + try: + rendered_value = self._render_template(param_template, variable_pool) + except Exception as e: + raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + else: + rendered_value = param_template else: - # 非模板参数(数字/布尔/普通字符串)直接保留原值 rendered_value = param_template rendered_parameters[param_name] = rendered_value diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 94f87287..2b849c94 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -84,7 +84,7 @@ class FileVariable(BaseVariable): total_bytes = 0 chunks = [] - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(follow_redirects=True) as client: async with client.stream("GET", self.value.url) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(8192): diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index a92b5629..c3fd82df 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -29,11 +29,8 @@ class Tenants(Base): contact_email = Column(String(255), nullable=True) # 联系人邮箱 contact_phone = Column(String(50), nullable=True) # 联系人电话 - # 租户套餐信息 - plan = Column(String(50), nullable=True) # 套餐类型 - plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间 - api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制 - status = Column(String(50), nullable=True, default='active') # 租户状态 + # 租户套餐信息(只读,从 tenant_subscriptions 动态获取) + status = Column(String(50), nullable=True, default='active', server_default='active') # 租户状态 # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") diff --git a/api/app/repositories/implicit_emotions_storage_repository.py b/api/app/repositories/implicit_emotions_storage_repository.py index b6c40b40..b665924d 100644 --- a/api/app/repositories/implicit_emotions_storage_repository.py +++ b/api/app/repositories/implicit_emotions_storage_repository.py @@ -5,16 +5,9 @@ Implicit Emotions Storage Repository 事务由调用方控制,仓储层只使用 flush/refresh """ import logging -from datetime import date, datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Generator, Optional - -class TimeFilterUnavailableError(Exception): - """redis_client 不可用,无法执行时间轴筛选。 - - 调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。 - """ - import redis from sqlalchemy import exists, not_, select from sqlalchemy.orm import Session @@ -25,6 +18,13 @@ from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage logger = logging.getLogger(__name__) +class TimeFilterUnavailableError(Exception): + """redis_client 不可用,无法执行时间轴筛选。 + + 调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。 + """ + + class ImplicitEmotionsStorageRepository: """隐性记忆和情绪存储仓储类""" @@ -216,9 +216,7 @@ class ImplicitEmotionsStorageRepository: """ from sqlalchemy import String as SAString from sqlalchemy import cast - CST = timezone(timedelta(hours=8)) - now_cst = datetime.now(CST) - today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None) + today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) tomorrow_start = today_start + timedelta(days=1) offset = 0 while True: diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 3139b851..072be1e2 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -328,7 +328,7 @@ class MemoryConfigRepository: if not db_config: db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None - + #TODO:部分更新没有用patch请求,是在Repository层中用先查再部分更新的方式实现的,后续可以考虑改成patch请求更符合RESTful设计原则 update_data = update.model_dump(exclude_unset=True) update_data.pop("config_id", None) diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index 8c477d39..03870b4d 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -263,16 +263,15 @@ class ModelConfigRepository: raise @staticmethod - def get_by_type(db: Session, model_type: ModelType, tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]: - """根据类型获取模型配置""" - db_logger.debug(f"根据类型查询模型配置: type={model_type}, tenant_id={tenant_id}, is_active={is_active}") - + def get_by_type(db: Session, model_types: List[ModelType], tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]: + """根据类型获取模型配置,支持多类型查询""" + db_logger.debug(f"根据类型查询模型配置: types={[t.value for t in model_types]}, tenant_id={tenant_id}, is_active={is_active}") + try: query = db.query(ModelConfig).options( joinedload(ModelConfig.api_keys) - ).filter(ModelConfig.type == model_type) - - # 添加租户过滤 + ).filter(ModelConfig.type.in_([t.value for t in model_types])) + if tenant_id: query = query.filter( or_( @@ -280,16 +279,18 @@ class ModelConfigRepository: ModelConfig.is_public ) ) - + if is_active: query = query.filter(ModelConfig.is_active) - - models = query.order_by(ModelConfig.name).all() + + query = query.filter(ModelConfig.is_composite == False) + + models = query.order_by(ModelConfig.created_at.desc()).all() db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}") return models - + except Exception as e: - db_logger.error(f"根据类型查询模型配置失败: type={model_type} - {str(e)}") + db_logger.error(f"根据类型查询模型配置失败: types={model_types} - {str(e)}") raise @staticmethod diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 03d51a7c..05a8c4b0 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -94,6 +94,8 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity END, e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END, e.aliases = CASE + // 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,知识抽取完全不写入 + WHEN entity.name IN ['用户', '我', 'User', 'I'] THEN e.aliases WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0 THEN CASE WHEN e.aliases IS NULL THEN entity.aliases diff --git a/api/app/repositories/user_repository.py b/api/app/repositories/user_repository.py index 2dd76b04..6874f9bf 100644 --- a/api/app/repositories/user_repository.py +++ b/api/app/repositories/user_repository.py @@ -297,6 +297,10 @@ def get_user_by_id(db: Session, user_id: uuid.UUID) -> Optional[User]: """根据ID获取用户""" return UserRepository(db).get_user_by_id(user_id) +def get_user_by_id_regardless_active(db: Session, user_id: uuid.UUID) -> Optional[User]: + """根据ID获取用户(不过滤 is_active,用于启用/禁用场景)""" + return db.query(User).filter(User.id == user_id).first() + def get_user_by_email(db: Session, email: str) -> Optional[User]: """根据邮箱获取用户""" return UserRepository(db).get_user_by_email(email) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 5f73cde1..e93c513d 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -44,6 +44,8 @@ class FileInput(BaseModel): upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)") url: Optional[str] = Field(None, description="远程URL(remote_url时必填)") file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)") + name: Optional[str] = Field(None, description="文件名") + size: Optional[int] = Field(None, description="文件大小(字节)") _content = None @@ -243,6 +245,7 @@ class ModelParameters(BaseModel): stop: Optional[List[str]] = Field(default=None, description="停止序列") deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)") thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)") + json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)") class VariableDefinition(BaseModel): diff --git a/api/app/schemas/memory_api_schema.py b/api/app/schemas/memory_api_schema.py index ff62355f..4cc548f3 100644 --- a/api/app/schemas/memory_api_schema.py +++ b/api/app/schemas/memory_api_schema.py @@ -4,9 +4,10 @@ This module defines Pydantic schemas for the Memory API Service endpoints, including request validation and response structures for read and write operations. """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional +import uuid -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator class MemoryWriteRequest(BaseModel): @@ -110,6 +111,30 @@ class MemoryReadRequest(BaseModel): class MemoryWriteResponse(BaseModel): """Response schema for memory write operation. + Attributes: + task_id: Celery task ID for status polling + status: Initial task status (PENDING) + end_user_id: End user ID the write was submitted for + """ + task_id: str = Field(..., description="Celery task ID for polling") + status: str = Field(..., description="Task status: PENDING") + end_user_id: str = Field(..., description="End user ID") + + +class TaskStatusResponse(BaseModel): + """Response schema for task status check. + + Attributes: + status: Task status (PENDING, STARTED, SUCCESS, FAILURE, SKIPPED) + result: Task result data (available when status is SUCCESS or FAILURE) + """ + status: str = Field(..., description="Task status") + result: Optional[Dict[str, Any]] = Field(None, description="Task result when completed") + + +class MemoryWriteSyncResponse(BaseModel): + """Response schema for synchronous memory write. + Attributes: status: Operation status (success or failed) end_user_id: End user ID that was written to @@ -118,8 +143,8 @@ class MemoryWriteResponse(BaseModel): end_user_id: str = Field(..., description="End user ID") -class MemoryReadResponse(BaseModel): - """Response schema for memory read operation. +class MemoryReadSyncResponse(BaseModel): + """Response schema for synchronous memory read. Attributes: answer: Generated answer from memory retrieval @@ -128,12 +153,25 @@ class MemoryReadResponse(BaseModel): """ answer: str = Field(..., description="Generated answer") intermediate_outputs: List[Dict[str, Any]] = Field( - default_factory=list, + default_factory=list, description="Intermediate retrieval outputs" ) end_user_id: str = Field(..., description="End user ID") +class MemoryReadResponse(BaseModel): + """Response schema for memory read operation. + + Attributes: + task_id: Celery task ID for status polling + status: Initial task status (PENDING) + end_user_id: End user ID the read was submitted for + """ + task_id: str = Field(..., description="Celery task ID for polling") + status: str = Field(..., description="Task status: PENDING") + end_user_id: str = Field(..., description="End user ID") + + class CreateEndUserRequest(BaseModel): """Request schema for creating an end user. @@ -141,10 +179,12 @@ class CreateEndUserRequest(BaseModel): other_id: External user identifier (required) other_name: Display name for the end user memory_config_id: Optional memory config ID. If not provided, uses workspace default. + app_id: Optional app ID to bind the end user to. """ other_id: str = Field(..., description="External user identifier (required)") other_name: Optional[str] = Field("", description="Display name") memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.") + app_id: Optional[str] = Field(None, description="App ID to bind the end user to") @field_validator("other_id") @classmethod @@ -192,6 +232,7 @@ class MemoryConfigItem(BaseModel): created_at: Optional[str] = Field(None, description="Creation timestamp") updated_at: Optional[str] = Field(None, description="Last update timestamp") +# ========== V1 记忆配置管理接口 Schema ========== class ListConfigsResponse(BaseModel): """Response schema for listing memory configs. @@ -202,3 +243,203 @@ class ListConfigsResponse(BaseModel): """ configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs") total: int = Field(0, description="Total number of configs") + +class ConfigCreateRequest(BaseModel): + """Request schema for creating a new memory config.""" + config_name: str = Field(..., description="Configuration name") + config_desc: Optional[str] = Field("", description="Configuration description") + scene_id: uuid.UUID = Field(..., description="Associated ontology scene ID (UUID, required)") + + llm_id: Optional[str] = Field(None, description="LLM model configuration ID") + embedding_id: Optional[str] = Field(None, description="Embedding model configuration ID") + rerank_id: Optional[str] = Field(None, description="Reranking model configuration ID") + reflection_model_id: Optional[str] = Field(None, description="Reflection model ID") + emotion_model_id: Optional[str] = Field(None, description="Emotion analysis model ID") + + @field_validator("config_name") + @classmethod + def validate_config_name(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_name is required and cannot be empty") + return v.strip() + +class ConfigUpdateRequest(BaseModel): + """Request schema for updating memory config basic info. + + Attributes: + config_id: Configuration UUID to update (required) + config_name: New configuration name + config_desc: New configuration description + scene_id: New associated ontology scene ID + """ + config_id: str = Field(..., description="Configuration ID to update") + config_name: Optional[str] = Field(None, description="Configuration name") + config_desc: Optional[str] = Field(None, description="Configuration description") + scene_id: Optional[uuid.UUID] = Field(None, description="Associated ontology scene ID") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + """Validate that config_id is not empty.""" + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class ConfigUpdateExtractedRequest(BaseModel): + """Request schema for updating memory config extracted parameters. + + Attributes: + config_id: Configuration UUID to update (required) + llm_id: Optional LLM model configuration ID + audio_id: Optional audio model configuration ID + vision_id: Optional vision model configuration ID + video_id: Optional video model configuration ID + embedding_id: Optional embedding model configuration ID + rerank_id: Optional reranking model configuration ID + enable_llm_dedup_blockwise: Optional toggle for LLM decision deduplication + enable_llm_disambiguation: Optional toggle for LLM decision disambiguation + deep_retrieval: Optional toggle for deep retrieval + + t_type_strict: Optional float (0-1) for type strictness threshold + t_name_strict: Optional float (0-1) for name strictness threshold + t_overall: Optional float (0-1) for overall strictness threshold + state: Optional boolean for config active state + chunker_strategy: Optional string for memory chunking strategy + statement_granularity: Optional int (1-3) for statement extraction granularity + include_dialogue_context: Optional boolean for including dialogue context in retrieval + max_context: Optional int for maximum dialogue context length in characters + pruning_enabled: Optional boolean to enable intelligent semantic pruning + pruning_scene: Optional string for semantic pruning scene + pruning_threshold: Optional float (0-0.9) for semantic pruning threshold + enable_self_reflexion: Optional boolean to enable self-reflexion + iteration_period: Optional string for reflexion iteration period in hours (1, 3, 6, 12, 24) + reflexion_range: Optional string for reflexion range (partial or all) + baseline: Optional string for baseline (TIME/FACT/TIME-FACT) + + """ + config_id: str = Field(..., description="Configuration ID (UUID)") + llm_id: Optional[str] = Field(None, description="LLM model configuration ID") + audio_id: Optional[str] = Field(None, description="Audio model ID") + vision_id: Optional[str] = Field(None, description="Vision model ID") + video_id: Optional[str] = Field(None, description="Video model ID") + embedding_id: Optional[str] = Field(None, description="Embedding model configuration ID") + rerank_id: Optional[str] = Field(None, description="Reranking model configuration ID") + enable_llm_dedup_blockwise: Optional[bool] = Field(None, description="Enable LLM decision deduplication") + enable_llm_disambiguation: Optional[bool] = Field(None, description="Enable LLM decision disambiguation") + deep_retrieval: Optional[bool] = Field(None, description="Deep retrieval toggle") + + t_type_strict: Optional[float] = Field(None, ge=0.0, le=1.0, description="type strictness threshold") + t_name_strict: Optional[float] = Field(None, ge=0.0, le=1.0, description="name strictness threshold") + t_overall: Optional[float] = Field(None, ge=0.0, le=1.0, description="overall strictness threshold") + state: Optional[bool] = Field(None, description="config active state") + # 句子提取 + chunker_strategy: Optional[str] = Field(None, description="memory chunking strategy") + statement_granularity: Optional[int] = Field(None, ge=1, le=3, description="statement extraction granularity") + include_dialogue_context: Optional[bool] = Field(None, description="whether to include dialogue context in retrieval") + max_context: Optional[int] = Field(None, gt=100, description="maximum dialogue context length in characters") + # 剪枝配置:与 runtime.json 中 pruning 段对应 + pruning_enabled: Optional[bool] = Field(None, description="whether to enable intelligent semantic pruning") + pruning_scene: Optional[str] = Field(None, description="semantic pruning scene") + pruning_threshold: Optional[float] = Field(None, ge=0.0, le=0.9, description="semantic pruning threshold (0-0.9)") + enable_self_reflexion: Optional[bool] = Field(None, description="whether to enable self-reflexion") + iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(None, description="reflexion iteration period in hours (1, 3, 6, 12, 24)") + reflexion_range: Optional[Literal["partial", "all"]] = Field(None, description="reflexion range: partial/all") + baseline: Optional[Literal["TIME", "FACT", "TIME-FACT"]] = Field(None, description="baseline: TIME/FACT/TIME-FACT") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class ConfigUpdateForgettingRequest(BaseModel): + """Request schema for updating memory config forgetting parameters. + + Attributes: + config_id: Configuration UUID to update (required) + decay_constant: Decay constant for forgetting + lambda_time: Time decay parameter + lambda_mem: Memory decay parameter + offset: Offset for forgetting curve + max_history_length: Maximum history length to consider for forgetting + forgetting_threshold: Threshold for forgetting + min_days_since_access: Minimum days since last access to trigger forgetting + enable_llm_summary: Whether to use LLM-generated summaries for forgetting + max_merge_batch_size: Maximum batch size for merging nodes during forgetting + forgetting_interval_hours: Interval in hours for periodic forgetting + + """ + model_config = ConfigDict(populate_by_name=True, extra="forbid") + config_id: str = Field(..., description="Configuration ID (UUID)") + decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="Decay constant for forgetting") + lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="Time decay parameter") + lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="Memory decay parameter") + offset: Optional[float] = Field(None, ge=0.0, le=1.0, description="Offset for forgetting curve") + max_history_length: Optional[int] = Field(None, ge=10, le=1000, description="Maximum history length to consider for forgetting") + forgetting_threshold: Optional[float] = Field(None, ge=0.0, le=1.0, description="Forgetting threshold") + min_days_since_access: Optional[int] = Field(None, ge=1, le=365, description="Minimum days since last access to trigger forgetting") + enable_llm_summary: Optional[bool] = Field(None, description="Whether to use LLM-generated summaries for forgetting") + max_merge_batch_size: Optional[int] = Field(None, ge=1, le=1000, description="Maximum batch size for merging nodes during forgetting") + forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="Interval in hours for periodic forgetting") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class EmotionConfigUpdateRequest(BaseModel): + """Request schema for updating memory config emotion parameters. + + Attributes: + config_id: Configuration UUID to update (required) + emotion_enabled: Whether to enable emotion extraction + emotion_model_id: Emotion analysis model ID + emotion_extract_keywords: Whether to extract emotion keywords + emotion_min_intensity: Minimum emotion intensity threshold (0.0-1.0) + emotion_enable_subject: Whether to enable subject classification for emotions + """ + config_id: str = Field(..., description="Configuration ID (UUID)") + emotion_enabled: bool = Field(..., description="Whether to enable emotion extraction") + emotion_model_id: Optional[str] = Field(None, description="Emotion analysis model ID") + emotion_extract_keywords: bool = Field(..., description="Whether to extract emotion keywords") + emotion_min_intensity: float = Field(..., ge=0.0, le=1.0, description="Minimum emotion intensity threshold") + emotion_enable_subject: bool = Field(..., description="Whether to enable subject classification for emotions") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class ReflectionConfigUpdateRequest(BaseModel): + """Request schema for updating memory config reflection parameters. + + Attributes: + config_id: Configuration UUID to update (required) + reflection_enabled: Whether to enable self-reflection + reflection_period_in_hours: Reflection iteration period in hours + reflexion_range: Reflection range (partial or all) + baseline: Baseline for reflection (TIME/FACT/TIME-FACT) + reflection_model_id: Reflection model ID + memory_verify: Whether to enable memory verification + quality_assessment: Whether to enable quality assessment + """ + config_id: str = Field(..., description="Configuration ID (UUID)") + reflection_enabled: bool = Field(..., description="Whether to enable self-reflection") + reflection_period_in_hours: str = Field(..., description="Reflection iteration period in hours") + reflexion_range: Literal["partial", "all"] = Field(..., description="Reflection range: partial/all") + baseline: Literal["TIME", "FACT", "TIME-FACT"] = Field(..., description="Baseline: TIME/FACT/TIME-FACT") + reflection_model_id: str = Field(..., description="Reflection model ID") + memory_verify: bool = Field(..., description="Whether to enable memory verification") + quality_assessment: bool = Field(..., description="Whether to enable quality assessment") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index bfcf6337..24dddd80 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -291,7 +291,7 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 pruning_threshold: Optional[float] = Field( None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)" ) - + #TODO:萃取引擎的更新的更新会带有反思引擎的参数,需判断业务是否需要,不需要可以重构 # 反思配置 enable_self_reflexion: Optional[bool] = Field(None, description="是否启用自我反思") iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field( diff --git a/api/app/services/api_key_service.py b/api/app/services/api_key_service.py index a49e8fe0..07d55198 100644 --- a/api/app/services/api_key_service.py +++ b/api/app/services/api_key_service.py @@ -248,6 +248,35 @@ class RateLimiterService: def __init__(self): self.redis = aio_redis + async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: + """ + 按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit + """ + now = time.time() + window_start = now - 1 # 1 秒窗口 + key = f"rate_limit:tenant_qps:{tenant_id}" + + async with self.redis.pipeline() as pipe: + # 清理 1 秒前的旧记录 + pipe.zremrangebyscore(key, 0, window_start) + # 加入当前请求(score=时间戳,member=时间戳+随机数保证唯一) + pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now}) + # 统计窗口内请求数 + pipe.zcard(key) + # 设置 key 过期(2 秒后自动清理) + pipe.expire(key, 2) + results = await pipe.execute() + + current = results[2] + remaining = max(0, limit - current) + reset_time = int(now) + 1 + + return current <= limit, { + "limit": limit, + "remaining": remaining, + "reset": reset_time, + } + async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: """ 检查QPS限制 diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index ec0c4b79..56e25713 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -26,6 +26,7 @@ from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService from app.services.workflow_service import WorkflowService +from app.models.file_metadata_model import FileMetadata logger = get_business_logger() @@ -119,6 +120,7 @@ class AppChatService: tools=tools, deep_thinking=model_parameters.get("deep_thinking", False), thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), + json_output=model_parameters.get("json_output", False), capability=api_key_obj.capability or [], ) @@ -218,11 +220,29 @@ class AppChatService: "reasoning_content": result.get("reasoning_content") } if files: + local_ids = [f.upload_file_id for f in files + if f.transfer_method.value == "local_file" and f.upload_file_id + and (not f.name or not f.size)] + meta_map = {} + if local_ids: + rows = self.db.query(FileMetadata).filter( + FileMetadata.id.in_(local_ids), + FileMetadata.status == "completed" + ).all() + meta_map = {str(r.id): r for r in rows} for f in files: - # url = await MultimodalService(self.db).get_file_url(f) + name, size = f.name, f.size + if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size): + meta = meta_map.get(str(f.upload_file_id)) + if meta: + name = name or meta.file_name + size = size or meta.file_size human_meta["files"].append({ "type": f.type, - "url": f.url + "url": f.url, + "name": name, + "size": size, + "file_type": f.file_type, }) if processed_files: @@ -373,6 +393,7 @@ class AppChatService: streaming=True, deep_thinking=model_parameters.get("deep_thinking", False), thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), + json_output=model_parameters.get("json_output", False), capability=api_key_obj.capability or [], ) @@ -509,10 +530,29 @@ class AppChatService: } if files: + local_ids = [f.upload_file_id for f in files + if f.transfer_method.value == "local_file" and f.upload_file_id + and (not f.name or not f.size)] + meta_map = {} + if local_ids: + rows = self.db.query(FileMetadata).filter( + FileMetadata.id.in_(local_ids), + FileMetadata.status == "completed" + ).all() + meta_map = {str(r.id): r for r in rows} for f in files: + name, size = f.name, f.size + if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size): + meta = meta_map.get(str(f.upload_file_id)) + if meta: + name = name or meta.file_name + size = size or meta.file_size human_meta["files"].append({ "type": f.type, - "url": f.url + "url": f.url, + "name": name, + "size": size, + "file_type": f.file_type, }) if processed_files: human_meta["history_files"] = { diff --git a/api/app/services/app_dsl_service.py b/api/app/services/app_dsl_service.py index 8c198be4..26e4098c 100644 --- a/api/app/services/app_dsl_service.py +++ b/api/app/services/app_dsl_service.py @@ -14,12 +14,14 @@ from app.models.app_model import App, AppType from app.models.appshare_model import AppShare from app.models.app_release_model import AppRelease from app.models.knowledge_model import Knowledge +from app.models.knowledgeshare_model import KnowledgeShare from app.models.models_model import ModelConfig from app.models.tool_model import ToolConfig as ToolConfigModel from app.models.skill_model import Skill from app.models.workflow_model import WorkflowConfig from app.services.workflow_service import WorkflowService from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter +from app.core.workflow.nodes.enums import NodeType from app.models.memory_config_model import MemoryConfig as MemoryConfigModel @@ -73,15 +75,14 @@ class AppDslService: AppType.MULTI_AGENT: "multi_agent_config", AppType.WORKFLOW: "workflow" }.get(app.type, "config") - config_data = self._enrich_release_config(app.type, release.config or {}) + config_data = self._enrich_release_config(app.type, release.config or {}, release.default_model_config_id) dsl = {**meta, "app": app_meta, config_key: config_data} return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml" - def _enrich_release_config(self, app_type: str, cfg: dict) -> dict: + def _enrich_release_config(self, app_type: str, cfg: dict, default_model_config_id=None) -> dict: if app_type == AppType.AGENT: enriched = {**cfg} - if "default_model_config_id" in cfg: - enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"]) + enriched["default_model_config_ref"] = self._model_ref(default_model_config_id) if "knowledge_retrieval" in cfg: enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"]) if "tools" in cfg: @@ -91,8 +92,7 @@ class AppDslService: return enriched if app_type == AppType.MULTI_AGENT: enriched = {**cfg} - if "default_model_config_id" in cfg: - enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"]) + enriched["default_model_config_ref"] = self._model_ref(default_model_config_id) if "master_agent_id" in cfg: enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"]) if "sub_agents" in cfg: @@ -229,8 +229,11 @@ class AppDslService: workspace_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID, + app_id: Optional[uuid.UUID] = None, ) -> tuple[App, list[str]]: - """解析 DSL,创建应用及配置,返回 (new_app, warnings)""" + """解析 DSL,创建或覆盖应用配置,返回 (app, warnings)。 + app_id 不为空时:校验类型一致后覆盖配置;为空时创建新应用。 + """ app_meta = dsl.get("app", {}) app_type = app_meta.get("type") if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW): @@ -239,6 +242,9 @@ class AppDslService: warnings: list[str] = [] now = datetime.datetime.now() + if app_id is not None: + return self._overwrite_dsl(dsl, app_id, app_type, workspace_id, tenant_id, warnings, now) + new_app = App( id=uuid.uuid4(), workspace_id=workspace_id, @@ -258,11 +264,57 @@ class AppDslService: self.db.add(new_app) self.db.flush() + self._write_config(new_app.id, app_type, dsl, workspace_id, tenant_id, warnings, now, create=True) + + self.db.commit() + self.db.refresh(new_app) + return new_app, warnings + + def _overwrite_dsl( + self, + dsl: dict, + app_id: uuid.UUID, + app_type: str, + workspace_id: uuid.UUID, + tenant_id: uuid.UUID, + warnings: list, + now: datetime.datetime, + ) -> tuple[App, list[str]]: + """覆盖已有应用的配置,类型不一致时抛出异常""" + app = self.db.query(App).filter( + App.id == app_id, + App.workspace_id == workspace_id, + App.is_active.is_(True) + ).first() + if not app: + raise ResourceNotFoundException("应用", str(app_id)) + if app.type != app_type: + raise BusinessException( + f"YAML 类型 '{app_type}' 与应用类型 '{app.type}' 不一致,无法导入", + BizCode.BAD_REQUEST + ) + + self._write_config(app_id, app_type, dsl, workspace_id, tenant_id, warnings, now, create=False) + + self.db.commit() + self.db.refresh(app) + return app, warnings + + def _write_config( + self, + app_id: uuid.UUID, + app_type: str, + dsl: dict, + workspace_id: uuid.UUID, + tenant_id: uuid.UUID, + warnings: list, + now: datetime.datetime, + create: bool, + ) -> None: + """写入(新建或覆盖)应用配置""" if app_type == AppType.AGENT: cfg = dsl.get("agent_config") or {} - self.db.add(AgentConfig( - id=uuid.uuid4(), - app_id=new_app.id, + fields = dict( system_prompt=cfg.get("system_prompt"), model_parameters=cfg.get("model_parameters"), default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings), @@ -272,16 +324,21 @@ class AppDslService: tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings), skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings), features=cfg.get("features", {}), - is_active=True, - created_at=now, updated_at=now, - )) + ) + if create: + self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) + else: + existing = self.db.query(AgentConfig).filter(AgentConfig.app_id == app_id).first() + if existing: + for k, v in fields.items(): + setattr(existing, k, v) + else: + self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) elif app_type == AppType.MULTI_AGENT: cfg = dsl.get("multi_agent_config") or {} - self.db.add(MultiAgentConfig( - id=uuid.uuid4(), - app_id=new_app.id, + fields = dict( orchestration_mode=cfg.get("orchestration_mode", "collaboration"), master_agent_name=cfg.get("master_agent_name"), model_parameters=cfg.get("model_parameters"), @@ -291,13 +348,24 @@ class AppDslService: routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings), execution_config=cfg.get("execution_config", {}), aggregation_strategy=cfg.get("aggregation_strategy", "merge"), - is_active=True, - created_at=now, updated_at=now, - )) + ) + if create: + self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) + else: + existing = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app_id).first() + if existing: + for k, v in fields.items(): + setattr(existing, k, v) + else: + self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) elif app_type == AppType.WORKFLOW: - adapter = MemoryBearAdapter(dsl) + raw_wf = dsl.get("workflow") or {} + raw_nodes = raw_wf.get("nodes") or [] + resolved_nodes = self._resolve_workflow_nodes(raw_nodes, tenant_id, workspace_id, warnings) + resolved_dsl = {**dsl, "workflow": {**raw_wf, "nodes": resolved_nodes}} + adapter = MemoryBearAdapter(resolved_dsl) if not adapter.validate_config(): raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST) result = adapter.parse_workflow() @@ -305,21 +373,39 @@ class AppDslService: warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}") for w in result.warnings: warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}") - wf = dsl.get("workflow") or {} - WorkflowService(self.db).create_workflow_config( - app_id=new_app.id, - nodes=[n.model_dump() for n in result.nodes], - edges=[e.model_dump() for e in result.edges], - variables=[v.model_dump() for v in result.variables], - execution_config=wf.get("execution_config", {}), - features=wf.get("features", {}), - triggers=wf.get("triggers", []), - validate=False, - ) - - self.db.commit() - self.db.refresh(new_app) - return new_app, warnings + wf_service = WorkflowService(self.db) + if create: + wf_service.create_workflow_config( + app_id=app_id, + nodes=[n.model_dump() for n in result.nodes], + edges=[e.model_dump() for e in result.edges], + variables=[v.model_dump() for v in result.variables], + execution_config=raw_wf.get("execution_config", {}), + features=raw_wf.get("features", {}), + triggers=raw_wf.get("triggers", []), + validate=False, + ) + else: + existing = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app_id).first() + if existing: + existing.nodes = [n.model_dump() for n in result.nodes] + existing.edges = [e.model_dump() for e in result.edges] + existing.variables = [v.model_dump() for v in result.variables] + existing.execution_config = raw_wf.get("execution_config", {}) + existing.features = raw_wf.get("features", {}) + existing.triggers = raw_wf.get("triggers", []) + existing.updated_at = now + else: + wf_service.create_workflow_config( + app_id=app_id, + nodes=[n.model_dump() for n in result.nodes], + edges=[e.model_dump() for e in result.edges], + variables=[v.model_dump() for v in result.variables], + execution_config=raw_wf.get("execution_config", {}), + features=raw_wf.get("features", {}), + triggers=raw_wf.get("triggers", []), + validate=False, + ) def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str: """生成唯一应用名称,同时检查本空间自有应用和共享到本空间的应用""" @@ -365,27 +451,63 @@ class AppDslService: def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]: if not ref: return None - kb = self.db.query(Knowledge).filter( - Knowledge.workspace_id == workspace_id, - Knowledge.name == ref.get("name") - ).first() - if not kb: - warnings.append(f"知识库 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置") - return str(kb.id) if kb else None + kb_id = ref.get("id") + if kb_id: + try: + kb_uuid = uuid.UUID(str(kb_id)) + kb_share = self.db.query(KnowledgeShare).filter( + KnowledgeShare.target_workspace_id == workspace_id, + KnowledgeShare.source_kb_id == kb_uuid + ).first() + if kb_share: + kb = self.db.query(Knowledge).filter( + Knowledge.id == kb_share.target_kb_id + ).first() + if kb and kb.status == 1: + return str(kb_share.target_kb_id) + kb = self.db.query(Knowledge).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.id == kb_uuid, + Knowledge.status == 1 + ).first() + if kb: + return str(kb.id) + except (ValueError, AttributeError): + pass + warnings.append(f"知识库 '{kb_id}' 未匹配,已置空,请导入后手动配置") + return None def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]: if not ref: return None - q = self.db.query(ToolConfigModel).filter( - ToolConfigModel.tenant_id == tenant_id, - ToolConfigModel.name == ref.get("name") - ) - if ref.get("tool_type"): - q = q.filter(ToolConfigModel.tool_type == ref["tool_type"]) - t = q.first() - if not t: - warnings.append(f"工具 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置") - return str(t.id) if t else None + tool_id = ref.get("id") + tool_name = ref.get("name") + if tool_id: + try: + tool_uuid = uuid.UUID(str(tool_id)) + t = self.db.query(ToolConfigModel).filter( + ToolConfigModel.id == tool_uuid, + ToolConfigModel.tenant_id == tenant_id, + ToolConfigModel.is_active.is_(True) + ).first() + if t: + return str(t.id) + except (ValueError, AttributeError): + pass + if tool_name: + q = self.db.query(ToolConfigModel).filter( + ToolConfigModel.tenant_id == tenant_id, + ToolConfigModel.name == tool_name + ) + if ref.get("tool_type"): + q = q.filter(ToolConfigModel.tool_type == ref["tool_type"]) + t = q.first() + if t: + return str(t.id) + warnings.append(f"工具 '{tool_name}' 未匹配,已置空,请导入后手动配置") + else: + warnings.append(f"工具 '{tool_id}' 未匹配,已置空,请导入后手动配置") + return None def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]: if not ref: @@ -427,6 +549,61 @@ class AppDslService: result.append(entry) return result + def _resolve_workflow_nodes(self, nodes: list, tenant_id: uuid.UUID, workspace_id: uuid.UUID, warnings: list) -> list: + """解析工作流节点中的工具ID和知识库ID,匹配不到则清空配置""" + resolved_nodes = [] + for node in nodes: + node_type = node.get("type") + config = dict(node.get("config") or {}) + node_label = node.get("name") or node.get("id") + if node_type == NodeType.TOOL.value: + tool_id = config.get("tool_id") + if not tool_id: + # tool_id 本身就是空,直接置空不重复 warning + config["tool_id"] = None + config["tool_parameters"] = {} + else: + tool_ref = {} + if isinstance(tool_id, str) and len(tool_id) >= 36: + try: + uuid.UUID(tool_id) + tool_ref["id"] = tool_id + except ValueError: + tool_ref["name"] = tool_id + else: + tool_ref["name"] = tool_id + resolved_tool_id = self._resolve_tool(tool_ref, tenant_id, []) + if resolved_tool_id: + config["tool_id"] = resolved_tool_id + else: + warnings.append(f"[{node_label}] 工具 '{tool_id}' 未匹配,已置空,请导入后手动配置") + config["tool_id"] = None + config["tool_parameters"] = {} + elif node_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + knowledge_bases = config.get("knowledge_bases") or [] + resolved_kbs = [] + for kb in knowledge_bases: + kb_id = kb.get("kb_id") + if not kb_id: + continue + kb_ref = {} + if isinstance(kb_id, str) and len(kb_id) >= 36: + try: + uuid.UUID(kb_id) + kb_ref["id"] = kb_id + except ValueError: + kb_ref["name"] = kb_id + else: + kb_ref["name"] = kb_id + resolved_id = self._resolve_kb(kb_ref, workspace_id, []) + if resolved_id: + resolved_kbs.append({**kb, "kb_id": resolved_id}) + else: + warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置") + config["knowledge_bases"] = resolved_kbs + resolved_nodes.append({**node, "config": config}) + return resolved_nodes + def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]: if not kr: return kr diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 534ab8d0..64651189 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1452,6 +1452,32 @@ class AppService: logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)}) return self._create_default_agent_config(app_id) + def get_default_model_parameters( + self, + *, + app_id: uuid.UUID, + ) -> "ModelParameters": + """获取 Agent 默认模型参数(不修改数据库) + + Args: + app_id: 应用ID + + Returns: + ModelParameters: 默认模型参数 + """ + logger.info("获取 Agent 默认模型参数", extra={"app_id": str(app_id)}) + + app = self._get_app_or_404(app_id) + + if app.type != "agent": + raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED) + + from app.schemas.app_schema import ModelParameters + default_model_parameters = ModelParameters() + + logger.info("获取 Agent 默认模型参数成功", extra={"app_id": str(app_id)}) + return default_model_parameters + def _create_default_agent_config(self, app_id: uuid.UUID) -> AgentConfig: """创建默认的 Agent 配置模板(不保存到数据库) diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 6e9f3544..61744ec7 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -544,7 +544,7 @@ class ConversationService: api_key=api_key, base_url=api_base, is_omni=is_omni, - support_thinking="thinking" in (capability or []), + capability=capability, ), type=ModelType(model_type) ) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 11011e6f..f6ebb191 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -599,6 +599,7 @@ class AgentRunService: tools=tools, deep_thinking=effective_params.get("deep_thinking", False), thinking_budget_tokens=effective_params.get("thinking_budget_tokens"), + json_output=effective_params.get("json_output", False), capability=api_key_config.get("capability", []), ) @@ -855,6 +856,7 @@ class AgentRunService: streaming=True, deep_thinking=effective_params.get("deep_thinking", False), thinking_budget_tokens=effective_params.get("thinking_budget_tokens"), + json_output=effective_params.get("json_output", False), capability=api_key_config.get("capability", []), ) @@ -1301,10 +1303,30 @@ class AgentRunService: "history_files": {} } if files: + from app.models.file_metadata_model import FileMetadata + local_ids = [f.upload_file_id for f in files + if f.transfer_method.value == "local_file" and f.upload_file_id + and (not f.name or not f.size)] + meta_map = {} + if local_ids: + rows = self.db.query(FileMetadata).filter( + FileMetadata.id.in_(local_ids), + FileMetadata.status == "completed" + ).all() + meta_map = {str(r.id): r for r in rows} for f in files: + name, size = f.name, f.size + if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size): + meta = meta_map.get(str(f.upload_file_id)) + if meta: + name = name or meta.file_name + size = size or meta.file_size human_meta["files"].append({ "type": f.type, - "url": f.url + "url": f.url, + "file_type": f.file_type, + "name": name, + "size": size }) # 保存 history_files,包含 provider 和 is_omni 信息 diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index c226348e..9a215cd6 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -679,9 +679,9 @@ class EmotionAnalyticsService: # 查询用户的实体和标签 query = """ - MATCH (e:Entity) + MATCH (e:ExtractedEntity) WHERE e.end_user_id = $end_user_id - RETURN e.name as name, e.type as type + RETURN e.name as name, e.entity_type as type ORDER BY e.created_at DESC LIMIT 20 """ diff --git a/api/app/services/implicit_memory_service.py b/api/app/services/implicit_memory_service.py index 4bd11deb..7a186f33 100644 --- a/api/app/services/implicit_memory_service.py +++ b/api/app/services/implicit_memory_service.py @@ -34,6 +34,7 @@ from app.schemas.implicit_memory_schema import ( UserMemorySummary, ) from app.schemas.memory_config_schema import MemoryConfig +from app.services.memory_base_service import MIN_MEMORY_SUMMARY_COUNT from sqlalchemy.orm import Session logger = logging.getLogger(__name__) @@ -379,12 +380,59 @@ class ImplicitMemoryService: raise + def _build_empty_profile(self) -> dict: + """构建 MemorySummary 不足时返回的固定空白画像数据""" + now_ms = int(datetime.utcnow().timestamp() * 1000) + insufficient = "Insufficient data for analysis" + + def _empty_dimension(name: str) -> dict: + return { + "evidence": [insufficient], + "reasoning": f"No clear evidence found for {name} dimension", + "percentage": 0.0, + "dimension_name": name, + "confidence_level": 20, + } + + def _empty_category(name: str) -> dict: + return { + "evidence": [insufficient], + "percentage": 25.0, + "category_name": name, + "trending_direction": None, + } + + return { + "habits": [], + "portrait": { + "aesthetic": _empty_dimension("aesthetic"), + "creativity": _empty_dimension("creativity"), + "literature": _empty_dimension("literature"), + "technology": _empty_dimension("technology"), + "historical_trends": None, + "analysis_timestamp": now_ms, + "total_summaries_analyzed": 0, + }, + "preferences": [], + "interest_areas": { + "art": _empty_category("art"), + "tech": _empty_category("tech"), + "music": _empty_category("music"), + "lifestyle": _empty_category("lifestyle"), + "analysis_timestamp": now_ms, + "total_summaries_analyzed": 0, + }, + } + async def generate_complete_profile( self, user_id: str ) -> dict: """生成完整的用户画像(包含所有4个模块) + 需要该用户的 MemorySummary 节点数量 >= 5 才会真正调用 LLM 生成画像, + 否则返回固定的空白画像数据。 + Args: user_id: 用户ID @@ -394,6 +442,16 @@ class ImplicitMemoryService: logger.info(f"生成完整用户画像: user={user_id}") try: + # 前置检查:查询该用户有效的 MemorySummary 节点数量(排除孤立节点) + from app.services.memory_base_service import MemoryBaseService + base_service = MemoryBaseService() + memory_summary_count = await base_service.get_valid_memory_summary_count(user_id) + logger.info(f"用户 MemorySummary 节点数量: {memory_summary_count} (user={user_id})") + + if memory_summary_count < MIN_MEMORY_SUMMARY_COUNT: + logger.info(f"MemorySummary 数量不足 {MIN_MEMORY_SUMMARY_COUNT}(当前 {memory_summary_count}),返回空白画像: user={user_id}") + return self._build_empty_profile() + # 并行调用4个分析方法 preferences, portrait, interest_areas, habits = await asyncio.gather( self.get_preference_tags(user_id=user_id), diff --git a/api/app/services/knowledge_service.py b/api/app/services/knowledge_service.py index bac02e96..94653db8 100644 --- a/api/app/services/knowledge_service.py +++ b/api/app/services/knowledge_service.py @@ -2,11 +2,14 @@ import uuid from sqlalchemy.orm import Session from app.models.user_model import User from app.models.knowledge_model import Knowledge +from app.models.workspace_model import Workspace +from app.models.models_model import ModelConfig from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate from app.repositories import knowledge_repository from app.core.logging_config import get_business_logger +from app.repositories.model_repository import ModelConfigRepository +from app.models.models_model import ModelType -# Obtain a dedicated logger for business logic business_logger = get_business_logger() @@ -60,13 +63,57 @@ def create_knowledge( db: Session, knowledge: KnowledgeCreate, current_user: User ) -> Knowledge: business_logger.info(f"Create a knowledge base: {knowledge.name}, creator: {current_user.username}") - + try: knowledge.created_by = current_user.id if knowledge.workspace_id is None: knowledge.workspace_id = current_user.current_workspace_id if knowledge.parent_id is None: knowledge.parent_id = knowledge.workspace_id + + workspace = db.query(Workspace).filter(Workspace.id == knowledge.workspace_id).first() + if not workspace: + raise Exception(f"Workspace {knowledge.workspace_id} not found") + + tenant_id = workspace.tenant_id + + if not knowledge.embedding_id: + embedding_models = ModelConfigRepository.get_by_type( + db=db, model_types=[ModelType.EMBEDDING], tenant_id=tenant_id, is_active=True + ) + if embedding_models: + knowledge.embedding_id = embedding_models[0].id + business_logger.debug(f"Auto-bind embedding model: {embedding_models[0].id}") + + if not knowledge.reranker_id: + rerank_models = ModelConfigRepository.get_by_type( + db=db, model_types=[ModelType.RERANK], tenant_id=tenant_id, is_active=True + ) + if rerank_models: + knowledge.reranker_id = rerank_models[0].id + business_logger.debug(f"Auto-bind rerank model: {rerank_models[0].id}") + + if not knowledge.llm_id: + llm_models = ModelConfigRepository.get_by_type( + db=db, model_types=[ModelType.LLM, ModelType.CHAT], tenant_id=tenant_id, is_active=True + ) + if llm_models: + knowledge.llm_id = llm_models[0].id + business_logger.debug(f"Auto-bind llm model: {llm_models[0].id}") + + if not knowledge.image2text_id: + image2text_models = db.query(ModelConfig).filter( + ModelConfig.tenant_id == tenant_id, + ModelConfig.type.in_([ModelType.CHAT.value]), + ModelConfig.capability.contains(["vision"]), + ModelConfig.is_active == True, + ModelConfig.is_composite == False + ).order_by(ModelConfig.created_at.desc()).all() + if not image2text_models: + raise Exception("租户下没有可用的视觉模型,创建知识库失败") + knowledge.image2text_id = image2text_models[0].id + business_logger.debug(f"Auto-bind image2text model: {image2text_models[0].id}") + business_logger.debug(f"Start creating the knowledge base: {knowledge.name}") db_knowledge = knowledge_repository.create_knowledge( db=db, knowledge=knowledge diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index 7087415e..bd90eee9 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -415,9 +415,11 @@ class LLMRouter: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - support_thinking="thinking" in (api_key_config.capability or []), - temperature=0.3, - max_tokens=500 + capability=api_key_config.capability, + extra_params={ + "temperature": 0.3, + "max_tokens": 500 + } ) logger.debug(f"创建 LLM 实例 - Provider: {api_key_config.provider}, Model: {api_key_config.model_name}") diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 206443bd..dfb3c2da 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -393,7 +393,7 @@ class MasterAgentRouter: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - support_thinking="thinking" in (api_key_config.capability or []), + capability=api_key_config.capability, extra_params = extra_params ) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index f62f526c..330b84ad 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -8,6 +8,8 @@ This service validates inputs and delegates to MemoryAgentService for core memor import uuid from typing import Any, Dict, Optional +from sqlalchemy.orm import Session + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.logging_config import get_logger @@ -15,7 +17,6 @@ from app.models.app_model import App from app.models.end_user_model import EndUser from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_agent_service import MemoryAgentService -from sqlalchemy.orm import Session logger = get_logger(__name__) @@ -124,7 +125,7 @@ class MemoryAPIService: except Exception as e: logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") - async def write_memory( + def write_memory( self, workspace_id: uuid.UUID, end_user_id: str, @@ -133,27 +134,28 @@ class MemoryAPIService: storage_type: str = "neo4j", user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: - """Write memory with validation. - + """Submit a memory write task via Celery. + Validates end_user exists and belongs to workspace, updates the end user's - memory_config_id, then delegates to MemoryAgentService.write_memory. - + memory_config_id, then dispatches write_message_task to Celery for async + processing with per-user fair locking. + Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as end_user_id) + end_user_id: End user identifier message: Message content to store config_id: Memory configuration ID (required) storage_type: Storage backend (neo4j or rag) user_rag_memory_id: Optional RAG memory ID - + Returns: - Dict with status and end_user_id - + Dict with task_id, status, and end_user_id + Raises: ResourceNotFoundException: If end_user not found - BusinessException: If end_user not in authorized workspace or write fails + BusinessException: If validation fails """ - logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") + logger.info(f"Submitting memory write for end_user: {end_user_id}, workspace: {workspace_id}") # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) @@ -161,9 +163,120 @@ class MemoryAPIService: # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) + # Convert to message list format expected by write_message_task + messages = message if isinstance(message, list) else [{"role": "user", "content": message}] + + from app.tasks import write_message_task + task = write_message_task.delay( + end_user_id, + messages, + config_id, + storage_type, + user_rag_memory_id or "", + ) + + logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}") + + return { + "task_id": task.id, + "status": "PENDING", + "end_user_id": end_user_id, + } + + def read_memory( + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + search_switch: str = "0", + config_id: str = "", + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Submit a memory read task via Celery. + + Validates end_user exists and belongs to workspace, updates the end user's + memory_config_id, then dispatches read_message_task to Celery for async processing. + + Args: + workspace_id: Workspace ID for resource validation + end_user_id: End user identifier + message: Query message + search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) + config_id: Memory configuration ID (required) + storage_type: Storage backend (neo4j or rag) + user_rag_memory_id: Optional RAG memory ID + + Returns: + Dict with task_id, status, and end_user_id + + Raises: + ResourceNotFoundException: If end_user not found + BusinessException: If validation fails + """ + logger.info(f"Submitting memory read for end_user: {end_user_id}, workspace: {workspace_id}") + + # Validate end_user exists and belongs to workspace + self.validate_end_user(end_user_id, workspace_id) + + # Update end user's memory_config_id + self._update_end_user_config(end_user_id, config_id) + + from app.tasks import read_message_task + task = read_message_task.delay( + end_user_id, + message, + [], # history + search_switch, + config_id, + storage_type, + user_rag_memory_id or "", + ) + + logger.info(f"Memory read task submitted: task_id={task.id}, end_user_id={end_user_id}") + + return { + "task_id": task.id, + "status": "PENDING", + "end_user_id": end_user_id, + } + + async def write_memory_sync( + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + config_id: str, + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Write memory synchronously (inline, no Celery). + + Validates end_user, then calls MemoryAgentService.write_memory directly. + Blocks until the write completes. Use for cases where the caller needs + immediate confirmation. + + Args: + workspace_id: Workspace ID for resource validation + end_user_id: End user identifier + message: Message content to store + config_id: Memory configuration ID (required) + storage_type: Storage backend (neo4j or rag) + user_rag_memory_id: Optional RAG memory ID + + Returns: + Dict with status and end_user_id + + Raises: + ResourceNotFoundException: If end_user not found + BusinessException: If write fails + """ + logger.info(f"Writing memory (sync) for end_user: {end_user_id}, workspace: {workspace_id}") + + self.validate_end_user(end_user_id, workspace_id) + self._update_end_user_config(end_user_id, config_id) + try: - # Delegate to MemoryAgentService - # Convert string message to list[dict] format expected by MemoryAgentService messages = message if isinstance(message, list) else [{"role": "user", "content": message}] result = await MemoryAgentService().write_memory( end_user_id=end_user_id, @@ -174,11 +287,8 @@ class MemoryAPIService: user_rag_memory_id=user_rag_memory_id or "", ) - logger.info(f"Memory write successful for end_user: {end_user_id}") + logger.info(f"Memory write (sync) successful for end_user: {end_user_id}") - # result may be a string "success" or a dict with a "status" key - # Preserve the full dict so callers don't silently lose extra fields - # (e.g. error codes, metadata) returned by MemoryAgentService. if isinstance(result, dict): return { **result, @@ -192,20 +302,17 @@ class MemoryAPIService: except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") - raise BusinessException( - message=str(e), - code=BizCode.MEMORY_CONFIG_NOT_FOUND - ) + raise BusinessException(message=str(e), code=BizCode.MEMORY_CONFIG_NOT_FOUND) except BusinessException: raise except Exception as e: - logger.error(f"Memory write failed for end_user {end_user_id}: {e}") + logger.error(f"Memory write (sync) failed for end_user {end_user_id}: {e}") raise BusinessException( message=f"Memory write failed: {str(e)}", code=BizCode.MEMORY_WRITE_FAILED ) - async def read_memory( + async def read_memory_sync( self, workspace_id: uuid.UUID, end_user_id: str, @@ -215,37 +322,34 @@ class MemoryAPIService: storage_type: str = "neo4j", user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: - """Read memory with validation. - - Validates end_user exists and belongs to workspace, updates the end user's - memory_config_id, then delegates to MemoryAgentService.read_memory. - + """Read memory synchronously (inline, no Celery). + + Validates end_user, then calls MemoryAgentService.read_memory directly. + Blocks until the read completes. Use for cases where the caller needs + the answer immediately. + Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as end_user_id) + end_user_id: End user identifier message: Query message search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) config_id: Memory configuration ID (required) storage_type: Storage backend (neo4j or rag) user_rag_memory_id: Optional RAG memory ID - + Returns: Dict with answer, intermediate_outputs, and end_user_id - + Raises: ResourceNotFoundException: If end_user not found - BusinessException: If end_user not in authorized workspace or read fails + BusinessException: If read fails """ - logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}") + logger.info(f"Reading memory (sync) for end_user: {end_user_id}, workspace: {workspace_id}") - # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - - # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) try: - # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( end_user_id=end_user_id, message=message, @@ -257,7 +361,7 @@ class MemoryAPIService: user_rag_memory_id=user_rag_memory_id or "" ) - logger.info(f"Memory read successful for end_user: {end_user_id}") + logger.info(f"Memory read (sync) successful for end_user: {end_user_id}") return { "answer": result.get("answer", ""), @@ -267,14 +371,11 @@ class MemoryAPIService: except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") - raise BusinessException( - message=str(e), - code=BizCode.MEMORY_CONFIG_NOT_FOUND - ) + raise BusinessException(message=str(e), code=BizCode.MEMORY_CONFIG_NOT_FOUND) except BusinessException: raise except Exception as e: - logger.error(f"Memory read failed for end_user {end_user_id}: {e}") + logger.error(f"Memory read (sync) failed for end_user {end_user_id}: {e}") raise BusinessException( message=f"Memory read failed: {str(e)}", code=BizCode.MEMORY_READ_FAILED diff --git a/api/app/services/memory_base_service.py b/api/app/services/memory_base_service.py index bc647752..e615af8b 100644 --- a/api/app/services/memory_base_service.py +++ b/api/app/services/memory_base_service.py @@ -265,12 +265,50 @@ async def Translation_English(modid, text, fields=None): # 其他类型(数字、布尔值、None等):原样返回 else: return text +# 隐性记忆画像生成所需的最低 MemorySummary 节点数量 +MIN_MEMORY_SUMMARY_COUNT = 5 + + class MemoryBaseService: """记忆服务基类,提供共享的辅助方法""" def __init__(self): self.neo4j_connector = Neo4jConnector() + async def get_valid_memory_summary_count( + self, + end_user_id: str + ) -> int: + """获取用户有效的 MemorySummary 节点数量(排除孤立节点)。 + + 只统计存在 DERIVED_FROM_STATEMENT 关系的 MemorySummary 节点。 + + Args: + end_user_id: 终端用户ID + + Returns: + 有效 MemorySummary 节点数量 + """ + try: + query = """ + MATCH (n:MemorySummary)-[:DERIVED_FROM_STATEMENT]->(:Statement) + WHERE n.end_user_id = $end_user_id + RETURN count(DISTINCT n) as count + """ + result = await self.neo4j_connector.execute_query( + query, end_user_id=end_user_id + ) + count = result[0]["count"] if result and len(result) > 0 else 0 + logger.debug( + f"有效 MemorySummary 节点数量: {count} (end_user_id={end_user_id})" + ) + return count + except Exception as e: + logger.error( + f"获取有效 MemorySummary 数量失败: {str(e)}", exc_info=True + ) + return 0 + @staticmethod def parse_timestamp(timestamp_value) -> Optional[int]: """ diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 7d6d1092..8fa9c9bf 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -233,7 +233,7 @@ class MemoryPerceptualService: api_key=model_config.api_key, base_url=model_config.api_base, is_omni=model_config.is_omni, - support_thinking="thinking" in (model_config.capability or []), + capability=model_config.capability, ) ) return llm, model_config diff --git a/api/app/services/model_parameter_merger.py b/api/app/services/model_parameter_merger.py index 4be83851..6911a9d5 100644 --- a/api/app/services/model_parameter_merger.py +++ b/api/app/services/model_parameter_merger.py @@ -47,7 +47,8 @@ class ModelParameterMerger: "n": 1, "stop": None, "deep_thinking": False, - "thinking_budget_tokens": None + "thinking_budget_tokens": None, + "json_output": False } # 合并参数:默认值 -> 模型配置 -> Agent 配置 diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index 4cbb3509..8807020b 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -125,9 +125,11 @@ class ModelConfigService: api_key=api_key, base_url=api_base, is_omni=is_omni, - support_thinking="thinking" in (capability or []), - temperature=0.7, - max_tokens=100 + capability=capability, + extra_params={ + "temperature": 0.7, + "max_tokens": 100 + } ) # 根据模型类型选择不同的验证方式 @@ -729,10 +731,21 @@ class ModelApiKeyService: @staticmethod def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool: """删除API Key""" - if not ModelApiKeyRepository.get_by_id(db, api_key_id): + api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) + if not api_key: raise BusinessException("API Key不存在", BizCode.NOT_FOUND) + model_config_ids = [mc.id for mc in api_key.model_configs] + success = ModelApiKeyRepository.delete(db, api_key_id) + + for model_config_id in model_config_ids: + model_config = ModelConfigRepository.get_by_id(db, model_config_id) + if model_config: + has_active_key = any(key.is_active for key in model_config.api_keys) + if not has_active_key and model_config.is_active: + model_config.is_active = False + db.commit() return success diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 216aeb6e..d30dc822 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -2616,9 +2616,11 @@ class MultiAgentOrchestrator: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - support_thinking="thinking" in (api_key_config.capability or []), - temperature=0.7, # 整合任务使用中等温度 - max_tokens=2000 + capability=api_key_config.capability, + extra_params={ + "temperature": 0.7, # 整合任务使用中等温度 + "max_tokens": 2000 + } ) # 创建 LLM 实例 @@ -2795,10 +2797,12 @@ class MultiAgentOrchestrator: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - support_thinking="thinking" in (api_key_config.capability or []), - temperature=0.7, - max_tokens=2000, - extra_params={"streaming": True} # 启用流式输出 + capability=api_key_config.capability, + extra_params={ + "temperature": 0.7, + "max_tokens": 2000, + "streaming": True # 启用流式输出 + } ) # 创建 LLM 实例 diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index fde8c4f9..1686a164 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -186,7 +186,7 @@ class PromptOptimizerService: api_key=api_config.api_key, base_url=api_config.api_base, is_omni=api_config.is_omni, - support_thinking="thinking" in (api_config.capability or []), + capability=api_config.capability, ), type=ModelType(model_config.type)) try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') @@ -227,10 +227,20 @@ class PromptOptimizerService: content = getattr(chunk, "content", chunk) if not content: continue - buffer += content + if isinstance(content, str): + buffer += content + elif isinstance(content, list): + for _ in content: + buffer += _["text"] + else: + logger.error(f"Unsupported content type - {content}") + raise Exception("Unsupported content type") cache = buffer[:-20] + last_idx = 19 + while cache and cache[-1] == '\\' and last_idx > 0: + cache = buffer[:-last_idx] + last_idx -= 1 - # 尝试找到 "prompt": " 开始位置 if prompt_finished: continue @@ -272,7 +282,7 @@ class PromptOptimizerService: def parser_prompt_variables(prompt: str): try: pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}' - matches = re.findall(pattern, prompt) + matches = re.findall(pattern, str(prompt)) variables = list(set(matches)) return variables except Exception as e: diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index b1e40a2d..37956d77 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -250,7 +250,8 @@ class SharedChatService: tools=tools, deep_thinking=model_parameters.get("deep_thinking", False), thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), - capability=api_key_obj.capability or [], + json_output=model_parameters.get("json_output", False), + capability=api_key_obj.capability, ) # 加载历史消息 @@ -455,6 +456,7 @@ class SharedChatService: streaming=True, deep_thinking=model_parameters.get("deep_thinking", False), thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), + json_output=model_parameters.get("json_output", False), capability=api_key_obj.capability or [], ) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 79f8fa05..4d120d8c 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -14,6 +14,7 @@ from pydantic import BaseModel, Field from sqlalchemy.orm import Session from app.core.logging_config import get_logger +from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.conversation_repository import ConversationRepository @@ -21,7 +22,7 @@ from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping -from app.services.memory_base_service import MemoryBaseService +from app.services.memory_base_service import MemoryBaseService, MIN_MEMORY_SUMMARY_COUNT from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService @@ -400,12 +401,21 @@ class UserMemoryService: # 构建响应数据(转换时间为毫秒时间戳) # 将 meta_data 中的 profile、knowledge_tags、behavioral_hints 平铺到顶层 meta = end_user_info_record.meta_data or {} + + # profile 列表字段截断:只返回前 MAX_PROFILE_LIST_SIZE 条(按时间从新到旧) + MAX_PROFILE_LIST_SIZE = 5 + profile = meta.get("profile") + if isinstance(profile, dict): + for key in ("role", "domain", "expertise", "interests"): + if isinstance(profile.get(key), list): + profile[key] = profile[key][:MAX_PROFILE_LIST_SIZE] + response_data = { "end_user_info_id": str(end_user_info_record.id), "end_user_id": str(end_user_info_record.end_user_id), "other_name": end_user_info_record.other_name, "aliases": end_user_info_record.aliases, - "profile": meta.get("profile"), + "profile": profile, "knowledge_tags": meta.get("knowledge_tags"), "behavioral_hints": meta.get("behavioral_hints"), "created_at": datetime_to_timestamp(end_user_info_record.created_at), @@ -477,7 +487,7 @@ class UserMemoryService: allowed_fields = {'other_name', 'aliases', 'meta_data'} # 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中 - _user_placeholder_names = {'用户', '我', 'User', 'I'} + _user_placeholder_names = _USER_PLACEHOLDER_NAMES # 过滤 other_name:不允许设置为占位名称 if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names: @@ -1504,7 +1514,7 @@ async def analytics_memory_types( 2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取) 3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量 4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取) - 5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一 + 5. 隐性记忆 (IMPLICIT_MEMORY) = MemorySummary 节点数量(需 >= MIN_MEMORY_SUMMARY_COUNT 才显示,否则为 0) 6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取) 7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取) 8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取) @@ -1561,23 +1571,15 @@ async def analytics_memory_types( logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}") work_count = 0 - # 获取隐性记忆数量(基于 Statement 节点数量的三分之一) + # 获取隐性记忆数量(基于有关联关系的 MemorySummary 节点数量,需 >= MIN_MEMORY_SUMMARY_COUNT 才计入) implicit_count = 0 if end_user_id: try: - # 查询 Statement 节点数量 - query = """ - MATCH (n:Statement) - WHERE n.end_user_id = $end_user_id - RETURN count(n) as count - """ - result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id) - statement_count = result[0]["count"] if result and len(result) > 0 else 0 - # 取三分之一作为隐性记忆数量 - implicit_count = round(statement_count / 3) - logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})") + memory_summary_count = await base_service.get_valid_memory_summary_count(end_user_id) + implicit_count = memory_summary_count if memory_summary_count >= MIN_MEMORY_SUMMARY_COUNT else 0 + logger.debug(f"隐性记忆数量(有效MemorySummary节点数): {implicit_count} (有效MemorySummary总数={memory_summary_count}, end_user_id={end_user_id})") except Exception as e: - logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}") + logger.warning(f"获取MemorySummary数量失败,隐性记忆数量设为0: {str(e)}") implicit_count = 0 # 原有的基于行为习惯的统计方式(已注释) @@ -1643,7 +1645,7 @@ async def analytics_memory_types( "WORKING_MEMORY": work_count, # 工作记忆(基于会话数量) "SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量) "EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆) - "IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3) + "IMPLICIT_MEMORY": implicit_count, # 隐性记忆(MemorySummary节点数,需>=MIN_MEMORY_SUMMARY_COUNT) "EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计) "EPISODIC_MEMORY": episodic_count, # 情景记忆 "FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值) diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index 3122d282..43a58c5f 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -285,7 +285,7 @@ def activate_user(db: Session, user_id_to_activate: uuid.UUID, current_user: Use try: # 查找用户 business_logger.debug(f"查找待激活用户: {user_id_to_activate}") - db_user = user_repository.get_user_by_id(db, user_id=user_id_to_activate) + db_user = user_repository.get_user_by_id_regardless_active(db, user_id=user_id_to_activate) if not db_user: business_logger.warning(f"用户不存在: {user_id_to_activate}") raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index b771c639..0d282d78 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -957,7 +957,10 @@ class WorkflowService: for file in message["content"]: human_meta["files"].append({ "type": file.get("type"), - "url": file.get("url") + "url": file.get("url"), + "file_type": file.get("origin_file_type"), + "name": file.get("name"), + "size": file.get("size") }) if message["role"] == "assistant": assistant_message = message["content"] diff --git a/api/app/tasks.py b/api/app/tasks.py index 5a71066a..8bbbdc6e 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -455,7 +455,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() if db_knowledge is None: logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found") - return f"build knowledge graph failed: knowledge not found" + return "build knowledge graph failed: knowledge not found" if not (db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False)): @@ -538,7 +538,7 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str): db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(knowledge_id)).first() if db_document is None or db_knowledge is None: logger.error(f"[GraphRAG] document={document_id} or knowledge={knowledge_id} not found") - return f"build_graphrag_for_document failed: record not found" + return "build_graphrag_for_document failed: record not found" graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) with_resolution = graphrag_conf.get("resolution", False) @@ -617,7 +617,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() if db_knowledge is None: logger.error(f"[SyncKB] knowledge={kb_id} not found") - return f"sync knowledge failed: knowledge not found" + return "sync knowledge failed: knowledge not found" # 1. get vector_service vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) @@ -3102,29 +3102,11 @@ def extract_user_metadata_task( logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}") return {"status": "SUCCESS", "result": "no_metadata_extracted"} - user_metadata, aliases_to_add, aliases_to_remove = extract_result - logger.info(f"[CELERY METADATA] LLM 别名新增: {aliases_to_add}, 移除: {aliases_to_remove}") - - # 4. 清洗元数据、覆盖写入元数据和别名 - def clean_metadata(raw: dict) -> dict: - """递归移除空字符串、空列表、空字典。""" - result = {} - for k, v in raw.items(): - if v == "" or v == []: - continue - if isinstance(v, dict): - cleaned = clean_metadata(v) - if cleaned: - result[k] = cleaned - else: - result[k] = v - return result - - raw_dict = user_metadata.model_dump(exclude_none=True) if user_metadata else {} - logger.info(f"[CELERY METADATA] LLM 输出完整元数据: {json.dumps(raw_dict, ensure_ascii=False)}") - - cleaned = clean_metadata(raw_dict) if raw_dict else {} - logger.info(f"[CELERY METADATA] 清洗后元数据: {json.dumps(cleaned, ensure_ascii=False)}") + metadata_changes, aliases_to_add, aliases_to_remove = extract_result + logger.info( + f"[CELERY METADATA] LLM 元数据变更: {[c.model_dump() for c in metadata_changes]}, " + f"别名新增: {aliases_to_add}, 移除: {aliases_to_remove}" + ) from datetime import datetime as dt, timezone as tz now = dt.now(tz.utc).isoformat() @@ -3152,15 +3134,49 @@ def extract_user_metadata_task( end_user = EndUserRepository(db).get_by_id(end_user_uuid) if info: - # 元数据覆盖写入 - if cleaned: - existing_meta = info.meta_data if info.meta_data else {} + # 4. 元数据增量更新(按 LLM 输出的变更操作逐条执行,所有字段均为列表类型) + if metadata_changes: + # 深拷贝,确保 SQLAlchemy 能检测到变更 + import copy + existing_meta = copy.deepcopy(info.meta_data) if info.meta_data else {} updated_at = dict(existing_meta.get("_updated_at", {})) - _update_timestamps(existing_meta, cleaned, updated_at, now) - final = dict(cleaned) - final["_updated_at"] = updated_at - info.meta_data = final - logger.info("[CELERY METADATA] 覆盖写入元数据") + + for change in metadata_changes: + field_path = change.field_path + action = change.action + value = change.value + + if not value or not value.strip(): + continue + + # 定位到目标字段的父级节点 + parts = field_path.split(".") + target = existing_meta + for part in parts[:-1]: + target = target.setdefault(part, {}) + leaf = parts[-1] + + current_list = target.get(leaf, []) + + if action == "set": + if value not in current_list: + # 新值插入列表头部,保证按时间从新到旧排序 + current_list.insert(0, value) + target[leaf] = current_list + logger.info(f"[CELERY METADATA] set {field_path} = {value}") + + elif action == "remove": + if value in current_list: + current_list.remove(value) + target[leaf] = current_list + logger.info(f"[CELERY METADATA] remove {value} from {field_path}") + + updated_at[field_path] = now + + existing_meta["_updated_at"] = updated_at + # 赋值深拷贝后的新对象,SQLAlchemy 会检测到字段变更并写入 + info.meta_data = existing_meta + logger.info(f"[CELERY METADATA] 增量更新元数据完成: {json.dumps(existing_meta, ensure_ascii=False)}") # 别名增量增删:(已有 - remove) + add old_aliases = info.aliases if info.aliases else [] @@ -3196,12 +3212,28 @@ def extract_user_metadata_task( from app.models.end_user_info_model import EndUserInfo initial_aliases = filtered_add # 新记录只有 add,没有 remove first_alias = initial_aliases[0] if initial_aliases else "" - if first_alias or cleaned: + + # 从变更操作构建初始元数据(所有字段均为列表类型) + initial_meta = {} + for change in metadata_changes: + if change.action == "set" and change.value is not None and change.value.strip(): + parts = change.field_path.split(".") + target = initial_meta + for part in parts[:-1]: + target = target.setdefault(part, {}) + leaf = parts[-1] + current_list = target.get(leaf, []) + if change.value not in current_list: + # 新值插入列表头部,保证按时间从新到旧排序 + current_list.insert(0, change.value) + target[leaf] = current_list + + if first_alias or initial_meta: new_info = EndUserInfo( end_user_id=end_user_uuid, other_name=first_alias or "", aliases=initial_aliases, - meta_data=cleaned if cleaned else None, + meta_data=initial_meta if initial_meta else None, ) db.add(new_info) if end_user and first_alias and ( diff --git a/api/app/version_info.json b/api/app/version_info.json index d07035e2..a094b64c 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,40 @@ { + "v0.3.0": { + "introduction": { + "codeName": "破晓", + "releaseDate": "2026-4-15", + "upgradePosition": "🐻 全面升级应用工作流、记忆智能与系统稳健性,引入版本化API、多模态记忆感知及大量工作流增强,打造更可靠、精准的 MemoryBear", + "coreUpgrades": [ + "1. 应用与API增强
* 版本化API调用支持:对外服务API支持指定版本调用
* 工作流检查清单:新增结构化验证步骤
* 深度思考参数精准控制:仅向支持深度推理的模型发送思考参数
* 提示器模型返回优化:优化提示器模型响应处理", + "2. 记忆智能 🧠
* 多模态记忆感知Agent:支持多模态记忆读取与写入
* OpenClaw内置工具:新增内置工具扩展Agent工具集", + "3. 用户体验 🎨
* 流式渲染稳定性优化:解决LLM流式输出页面抖动问题
* 记忆中枢更名:「记忆相关」更名为「记忆中枢」", + "4. 工作流改进 ⚙️
* 三级变量模板转换:支持三级变量解析
* VL模型Token统计:修复模型组合中VL模型Token未统计问题
* 导入工作流功能特性同步:正确同步开场白、引用等属性
* 会话变量名称唯一性校验:防止变量名冲突
* 文件类型提取修复:正确提取file.type信息
* 条件分支显示修复:值为0或会话变量时正确渲染
* Object/Array校验规则:防止JSON序列化错误
* HTTP请求Body字段修正:body字段从name改为key", + "5. 知识库 📚
* Embedding Token截断安全边界:统一添加8000 token截断,优化Excel独立chunk处理", + "6. 稳健性与缺陷修复 🔧
* 原子性更新与批量访问失败修复
* 对话别名提取错误修复
* 工作流别名提取修正(区分用户和AI回复)
* RAG记忆分页数据修复
* 隐式记忆详情显示修复
* 向量查询驱动关闭异常修复
* 用户管理启停异常修复
* 模型列表筛选不一致修复", + "
", + "v0.3.0 标志着 MemoryBear 向生产成熟度迈出坚实一步。后续版本将持续深化工作流表达力、记忆检索精度和跨模态理解能力,强化复杂Agent编排支持,稳固大规模生产部署基础。", + "
", + "MemoryBear — 破晓 🐻✨" + ] + }, + "introduction_en": { + "codeName": "PoXiao", + "releaseDate": "2026-4-15", + "upgradePosition": "🐻 Comprehensive upgrades across application workflows, memory intelligence, and system robustness — introducing versioned APIs, multimodal memory perception, and extensive workflow enhancements for a more reliable MemoryBear", + "coreUpgrades": [ + "1. Application & API Enhancements
* Versioned API Support: External APIs now support version-specific calls
* Workflow Checklist: Structured validation steps before deployment
* Deep Thinking Parameter Control: Only send thinking params to supported models
* Prompt Optimizer Return Optimization: Improved prompt optimizer response handling", + "2. Memory Intelligence 🧠
* Multimodal Memory Perception Agent: Read/write multimodal memory
* OpenClaw Built-in Tool: New built-in tool for agent operations", + "3. User Experience 🎨
* Streaming Render Stabilization: Eliminated page jitter during LLM output
* Memory Hub Renaming: Renamed to better reflect central memory role", + "4. Workflow Improvements ⚙️
* Three-Level Variable Template Conversion: Support for three-level variable resolution
* VL Model Token Tracking: Fixed token tracking for VL models in model groups
* Imported Workflow Feature Sync: Properly sync opening messages, citations, etc.
* Session Variable Name Uniqueness: Prevent variable name conflicts
* File Type Extraction Fix: Correctly extract file.type information
* Condition Branch Display Fix: Correct rendering for value 0 or session variables
* Object/Array Validation Rules: Prevent JSON serialization save errors
* HTTP Request Body Key Fix: Body field uses key instead of name", + "5. Knowledge Base 📚
* Embedding Token Truncation Safety: Unified 8000-token boundary, optimized Excel chunk processing", + "6. Robustness & Bug Fixes 🔧
* Atomic update & batch access failure fixes
* Conversation alias extraction fix
* Workflow alias extraction correction (user vs AI distinction)
* RAG memory pagination fix
* Implicit memory detail display fix
* Vector query driver closed exception fix
* User management enable/disable fix
* Model list filter inconsistency fix", + "
", + "v0.3.0 marks a meaningful step toward production maturity for MemoryBear. Upcoming releases will deepen workflow expressiveness, memory retrieval precision, and cross-modal understanding while strengthening complex agent orchestration and large-scale deployment foundations.", + "
", + "MemoryBear — Daybreak 🐻✨" + ] + } + }, "v0.2.10": { "introduction": { "codeName": "炼剑", diff --git a/web/package.json b/web/package.json index b41ab9b5..1f1fc397 100644 --- a/web/package.json +++ b/web/package.json @@ -93,7 +93,8 @@ "typescript-eslint": "^8.45.0", "unplugin-auto-import": "^20.2.0", "unplugin-vue-components": "^29.1.0", - "vite": "npm:rolldown-vite@7.1.14" + "vite": "npm:rolldown-vite@7.1.14", + "vite-plugin-svgr": "^5.2.0" }, "overrides": { "vite": "npm:rolldown-vite@7.1.14" diff --git a/web/src/App.tsx b/web/src/App.tsx index a10f9409..1af38372 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -16,7 +16,7 @@ import { ConfigProvider, App as AntdApp } from 'antd'; -import { useTranslation } from 'react-i18next'; +import i18n from 'i18next'; import { lightTheme } from './styles/antdThemeConfig.ts' import router from './routes'; @@ -29,11 +29,58 @@ import 'dayjs/plugin/utc' import { cookieUtils } from './utils/request'; import { useUser } from '@/store/user'; +import menuJson from '@/store/menu.json'; + +type MenuEntry = { path: string; i18nKey: string }; + +function flattenMenuEntries(list: any[]): MenuEntry[] { + const result: MenuEntry[] = []; + for (const item of list) { + if (item.path && item.i18nKey && item.type !== 'group') result.push({ path: item.path, i18nKey: item.i18nKey }); + if (item.subs?.length) result.push(...flattenMenuEntries(item.subs)); + } + return result; +} + +const menuEntries: MenuEntry[] = flattenMenuEntries([...menuJson.manage, ...menuJson.space]); + +function pathMatches(pattern: string, path: string): boolean { + if (pattern === path) return true; + if (pattern.includes(':')) { + return new RegExp('^' + pattern.replace(/:[\w-]+/g, '[^/]+') + '$').test(path); + } + return false; +} + +function getPageTitle(pathname: string): string { + const appName = i18n.t('memoryBear'); + const entry = menuEntries.find(e => pathMatches(e.path, pathname)); + if (!entry) return appName; + return `${i18n.t(entry.i18nKey)} - ${appName}`; +} + +const SKIP_TITLE_PATTERNS = [ + '/user-memory/detail/:id/:type', + '/forgetting-engine/:id', + '/memory-extraction-engine/:id', + '/emotion-engine/:id', + '/reflection-engine/:id', +]; + + + function App() { - const { t } = useTranslation(); const { locale, language, timeZone } = useI18n() const { checkJump } = useUser(); + useEffect(() => { + const unsubscribe = router.subscribe(({ location }) => { + if (SKIP_TITLE_PATTERNS.some(p => pathMatches(p, location.pathname))) return; + document.title = getPageTitle(location.pathname); + }); + return () => unsubscribe(); + }, []) + useEffect(() => { const authToken = cookieUtils.get('authToken') if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump') && !window.location.hash.includes('#/invite-register')) { @@ -44,7 +91,9 @@ function App() { }, []) useEffect(() => { - document.title = t('memoryBear') + if (!SKIP_TITLE_PATTERNS.some(p => pathMatches(p, router.state.location.pathname))) { + document.title = getPageTitle(router.state.location.pathname) + } dayjs.locale(language) localStorage.setItem('language', language) }, [language]) diff --git a/web/src/api/application.ts b/web/src/api/application.ts index a5730289..5614232e 100644 --- a/web/src/api/application.ts +++ b/web/src/api/application.ts @@ -174,4 +174,8 @@ export const getAppLogsUrl = (app_id: string) => `/apps/${app_id}/logs` // Get full conversation message history export const getAppLogDetail = (app_id: string, conversation_id: string) => { return request.get(`/apps/${app_id}/logs/${conversation_id}`) +} +// Reset agent model config to default +export const resetAppModelConfig = (app_id: string) => { + return request.get(`/apps/${app_id}/model/parameters/default`) } \ No newline at end of file diff --git a/web/src/api/package.ts b/web/src/api/package.ts new file mode 100644 index 00000000..f9cd2f74 --- /dev/null +++ b/web/src/api/package.ts @@ -0,0 +1,8 @@ +import { request } from '@/utils/request' + +import type { Package } from '@/views/Package/types' +// 套餐列表 +export const getPackageListUrl = `/package-plans` +export const getPackageList = (query?: { category?: Package['category']; status?: boolean; }) => { + return request.get(getPackageListUrl, query) +} \ No newline at end of file diff --git a/web/src/api/user.ts b/web/src/api/user.ts index 72a3ad73..0752f019 100644 --- a/web/src/api/user.ts +++ b/web/src/api/user.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:00:23 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-25 11:17:44 + * @Last Modified time: 2026-04-14 18:36:01 */ import { request } from '@/utils/request' import type { CreateModalData, ChangeEmailModalForm } from '@/views/UserManagement/types' @@ -56,4 +56,9 @@ export const sendEmailCode = (data: { email: string }) => { // Verify code and change email export const changeEmail = (data: ChangeEmailModalForm) => { return request.put('/users/change-email', data) +} + +// 获取租户套餐信息 +export const getTenantSubscription = () => { + return request.get('/tenant/subscription') } \ No newline at end of file diff --git a/web/src/assets/images/application/export.svg b/web/src/assets/images/application/export.svg new file mode 100644 index 00000000..c07a346d --- /dev/null +++ b/web/src/assets/images/application/export.svg @@ -0,0 +1,17 @@ + + + 导出 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/application/import.svg b/web/src/assets/images/application/import.svg new file mode 100644 index 00000000..6dde8f3c --- /dev/null +++ b/web/src/assets/images/application/import.svg @@ -0,0 +1,17 @@ + + + 导入 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/common/close_grey.svg b/web/src/assets/images/common/close_grey.svg new file mode 100644 index 00000000..6797b67f --- /dev/null +++ b/web/src/assets/images/common/close_grey.svg @@ -0,0 +1,15 @@ + + + 关闭 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/index/arrow_right_dark.svg b/web/src/assets/images/index/arrow_right_dark.svg new file mode 100644 index 00000000..b2742d11 --- /dev/null +++ b/web/src/assets/images/index/arrow_right_dark.svg @@ -0,0 +1,16 @@ + + + 编组 5 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/logout.svg b/web/src/assets/images/logout.svg deleted file mode 100644 index eedaccc4..00000000 --- a/web/src/assets/images/logout.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - 退出 - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/logout_grey.svg b/web/src/assets/images/logout_grey.svg new file mode 100644 index 00000000..b9b566c3 --- /dev/null +++ b/web/src/assets/images/logout_grey.svg @@ -0,0 +1,19 @@ + + + 退出 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/logout_hover.svg b/web/src/assets/images/logout_hover.svg deleted file mode 100644 index d77ab292..00000000 --- a/web/src/assets/images/logout_hover.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - 退出 - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/menuNew/package_bg.png b/web/src/assets/images/menuNew/package_bg.png new file mode 100644 index 00000000..cbed6f7a Binary files /dev/null and b/web/src/assets/images/menuNew/package_bg.png differ diff --git a/web/src/assets/images/package/api_ops.svg b/web/src/assets/images/package/api_ops.svg new file mode 100644 index 00000000..47512f69 --- /dev/null +++ b/web/src/assets/images/package/api_ops.svg @@ -0,0 +1,17 @@ + + + 频次 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/app.svg b/web/src/assets/images/package/app.svg new file mode 100644 index 00000000..699e5d87 --- /dev/null +++ b/web/src/assets/images/package/app.svg @@ -0,0 +1,17 @@ + + + 应用 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/arrow.svg b/web/src/assets/images/package/arrow.svg new file mode 100644 index 00000000..675d3dee --- /dev/null +++ b/web/src/assets/images/package/arrow.svg @@ -0,0 +1,13 @@ + + + 编组 49 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/disable.svg b/web/src/assets/images/package/disable.svg new file mode 100644 index 00000000..7e23d26f --- /dev/null +++ b/web/src/assets/images/package/disable.svg @@ -0,0 +1,18 @@ + + + 编组 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/enable.svg b/web/src/assets/images/package/enable.svg new file mode 100644 index 00000000..3df8f472 --- /dev/null +++ b/web/src/assets/images/package/enable.svg @@ -0,0 +1,18 @@ + + + 编组 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/end_user.svg b/web/src/assets/images/package/end_user.svg new file mode 100644 index 00000000..e6109b18 --- /dev/null +++ b/web/src/assets/images/package/end_user.svg @@ -0,0 +1,19 @@ + + + 终端 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/knowledge.svg b/web/src/assets/images/package/knowledge.svg new file mode 100644 index 00000000..3858efe1 --- /dev/null +++ b/web/src/assets/images/package/knowledge.svg @@ -0,0 +1,17 @@ + + + 知识库容量 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/memory_config.svg b/web/src/assets/images/package/memory_config.svg new file mode 100644 index 00000000..a1b38c5e --- /dev/null +++ b/web/src/assets/images/package/memory_config.svg @@ -0,0 +1,20 @@ + + + 记忆引擎 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/model.svg b/web/src/assets/images/package/model.svg new file mode 100644 index 00000000..23483fc0 --- /dev/null +++ b/web/src/assets/images/package/model.svg @@ -0,0 +1,17 @@ + + + 模型 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/ontology.svg b/web/src/assets/images/package/ontology.svg new file mode 100644 index 00000000..ff94829b --- /dev/null +++ b/web/src/assets/images/package/ontology.svg @@ -0,0 +1,17 @@ + + + 本体工程 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/skill.svg b/web/src/assets/images/package/skill.svg new file mode 100644 index 00000000..195248d9 --- /dev/null +++ b/web/src/assets/images/package/skill.svg @@ -0,0 +1,17 @@ + + + 技能 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/sla.svg b/web/src/assets/images/package/sla.svg new file mode 100644 index 00000000..10e4ce10 --- /dev/null +++ b/web/src/assets/images/package/sla.svg @@ -0,0 +1,19 @@ + + + SLA + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/space.svg b/web/src/assets/images/package/space.svg new file mode 100644 index 00000000..6775932d --- /dev/null +++ b/web/src/assets/images/package/space.svg @@ -0,0 +1,17 @@ + + + 空间 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/package/technical_support.svg b/web/src/assets/images/package/technical_support.svg new file mode 100644 index 00000000..d9b4251e --- /dev/null +++ b/web/src/assets/images/package/technical_support.svg @@ -0,0 +1,17 @@ + + + 合规 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index 5c722e45..f28b5dce 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:17 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-10 18:46:57 + * @Last Modified time: 2026-04-14 10:13:56 */ import { type FC, useRef, useEffect, useState } from 'react' import clsx from 'clsx' @@ -174,6 +174,7 @@ const ChatContent: FC = ({ ) } + const documentType = (file.file_type || file.type)?.split('/') return ( = ({ >
{file.name}
-
{file.type?.split('/')[file.type?.split('/').length - 1]} · {file.size}
+
{documentType?.[documentType.length - 1]} · {file.size}
) diff --git a/web/src/components/CodeMirrorEditor/index.tsx b/web/src/components/CodeMirrorEditor/index.tsx index ec2a6780..23729dcc 100644 --- a/web/src/components/CodeMirrorEditor/index.tsx +++ b/web/src/components/CodeMirrorEditor/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-04 17:20:52 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-04 17:20:52 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-04-16 11:46:39 */ import { useEffect, useRef, useMemo } from 'react'; import { EditorView, basicSetup } from 'codemirror'; @@ -35,7 +35,7 @@ interface CodeMirrorEditorProps { height?: string; size?: 'default' | 'small'; placeholder?: string; - variant?: 'outlined' | 'borderless'; + variant?: 'outlined' | 'borderless' | 'filled'; } /** @@ -156,7 +156,7 @@ const CodeMirrorEditor = ({
); }; diff --git a/web/src/components/Header/index.module.css b/web/src/components/Header/index.module.css index d39c91ec..525a2432 100644 --- a/web/src/components/Header/index.module.css +++ b/web/src/components/Header/index.module.css @@ -12,6 +12,14 @@ font-weight: 500; font-style: normal; } +.breadcrumbTitle { + display: inline-block; + max-width: 200px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + vertical-align: bottom; +} .header :global(.ant-breadcrumb) { line-height: 31px; } diff --git a/web/src/components/Header/index.tsx b/web/src/components/Header/index.tsx index 49988223..de87dcfc 100644 --- a/web/src/components/Header/index.tsx +++ b/web/src/components/Header/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-02 15:07:49 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-07 12:18:58 + * @Last Modified time: 2026-04-16 10:31:21 */ /** * AppHeader Component @@ -14,7 +14,7 @@ */ import { type FC, useRef, useState } from 'react'; -import { Layout, Dropdown, Breadcrumb, Flex } from 'antd'; +import { Layout, Dropdown, Breadcrumb, Flex, Tooltip } from 'antd'; import type { MenuProps, BreadcrumbProps } from 'antd'; import { useTranslation } from 'react-i18next'; import { useLocation } from 'react-router-dom'; @@ -31,7 +31,7 @@ const { Header } = Layout; /** * @param source - Breadcrumb source type ('space' or 'manage'), defaults to 'manage' */ -const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { +const AppHeader: FC<{ source?: 'space' | 'manage'; }> = ({ source = 'manage' }) => { const { t } = useTranslation(); const location = useLocation(); const settingModalRef = useRef(null) @@ -39,7 +39,7 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { const { user, logout } = useUser(); const { allBreadcrumbs } = useMenu(); - + /** * Dynamically select breadcrumb source based on current route * - Knowledge base list: uses 'space' breadcrumb @@ -48,24 +48,24 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { */ const getBreadcrumbSource = () => { const pathname = location.pathname; - + // Knowledge base list page uses default space breadcrumb if (pathname === '/knowledge-base') { return 'space'; } - + // Knowledge base detail pages use independent breadcrumb if (pathname.includes('/knowledge-base/') && pathname !== '/knowledge-base') { return 'space-detail'; } - + // Other pages use the passed source return source; }; - + const breadcrumbSource = getBreadcrumbSource(); const breadcrumbs = allBreadcrumbs[breadcrumbSource] || []; - + /** Handle user logout */ const handleLogout = () => { @@ -76,9 +76,11 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { const userMenuItems: MenuProps['items'] = [ { key: '1', - icon: - {/[\u4e00-\u9fa5]/.test(user.username) ? user.username.slice(0, 2) : user.username?.[0]} - , + icon: user.username + ? + {/[\u4e00-\u9fa5]/.test(user.username) ? user.username.slice(-2) : user.username[0]} + + : null, label: (<>
{user.username}
{user.email}
@@ -127,7 +129,7 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { onClick: handleLogout, }, ]; - + /** * Format breadcrumb items with proper titles, paths, and click handlers * - Translates i18n keys to display text @@ -135,32 +137,34 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { * - Disables navigation for the last breadcrumb item */ const formatBreadcrumbNames = () => { - return breadcrumbs.filter(item => item.type !== 'group').map((menu, index) => { + const filtered = breadcrumbs.filter(item => item.type !== 'group'); + return filtered.map((menu, index) => { + const label = menu.i18nKey ? t(menu.i18nKey) : menu.label; + const isLast = index === filtered.length - 1; const item: any = { - title: menu.i18nKey ? t(menu.i18nKey) : menu.label, + title: ( + + {label} + + ), }; - - // If it's the last item, don't set path - if (index === breadcrumbs.length - 1) { - return item; + + if (!isLast) { + if ((menu as any).onClick) { + item.onClick = (e: React.MouseEvent) => { + e.preventDefault(); + (menu as any).onClick(e); + }; + item.href = '#'; + } else if (menu.path && menu.path !== '#') { + item.path = menu.path; + } } - - // If has custom onClick, use onClick and set href to '#' to show pointer cursor - if ((menu as any).onClick) { - item.onClick = (e: React.MouseEvent) => { - e.preventDefault(); - (menu as any).onClick(e); - }; - item.href = '#'; - } else if (menu.path && menu.path !== '#') { - // Only set path when path is not '#' - item.path = menu.path; - } - + return item; }); } - + const [open, setOpen] = useState(false); const handleOpenChange = (open: boolean) => { setOpen(open); @@ -179,9 +183,9 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { overlayClassName={styles.userDropdown} > - - {/[\u4e00-\u9fa5]/.test(user.username) ? user.username.slice(user.username.length, -2) : user.username[0]} - + {user.username && + {/[\u4e00-\u9fa5]/.test(user.username) ? user.username.slice(-2) : user.username[0]} + } {user.username}
void; } -const ModelSelect: FC = ({ - params, - placeholder, - fontClassName, - isAutoFetch = true, - initialData = [], - ...props -}) => { +const ModelSelect: FC = ({ params, placeholder, fontClassName, isAutoFetch = true, initialData = [], updateOptions, ...props }) => { const { t } = useTranslation(); const [options, setOptions] = useState([]); @@ -60,6 +54,10 @@ const ModelSelect: FC = ({ ); }; + useEffect(() => { + if (updateOptions) updateOptions([...options, ...initialData]); + }, [options, initialData]) + return ( { + form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'sub_variable_condition', 'conditions', subIndex], { + key: value, + input_type: value === 'size' ? 'constant' : undefined, + value: undefined, + operator: value === 'size' ? 'ge' : 'eq', + }); + }} + /> + + + + + { handleInputTypeChange(caseIndex, conditionIndex, subIndex); }} + className="rb:w-20!" + /> + + + + {subInputType === 'variable' + ? + : { form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'right'], value); }} + /> + } + + + : + {subLeft === 'type' + ? handleInputTypeChange(caseIndex, conditionIndex)} className="rb:w-20!" /> - + {inputType === 'variable' ? = ({ {['boolean', 'array[boolean]'].includes(leftFieldType as string) ? - : + : } ) diff --git a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx index 2a976bf0..bd62c490 100644 --- a/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx +++ b/web/src/views/Workflow/components/Properties/CodeExecution/index.tsx @@ -94,7 +94,7 @@ const CodeExecution: FC = ({ options }) => { { label: 'JAVASCRIPT', value: 'javascript' } ]} popupMatchSelectWidth={false} - className={`rb:font-medium! rb:w-25! rb:h-4! rb:p-0! ${styles.editor}`} + className={`rb:font-medium! rb:w-25! rb:h-4! rb:py-0! rb:px-2! ${styles.editor}`} onChange={handleChangeLanguage} variant="borderless" /> diff --git a/web/src/views/Workflow/components/Properties/ConditionList/index.tsx b/web/src/views/Workflow/components/Properties/ConditionList/index.tsx index ddf92971..6ca0fb05 100644 --- a/web/src/views/Workflow/components/Properties/ConditionList/index.tsx +++ b/web/src/views/Workflow/components/Properties/ConditionList/index.tsx @@ -155,7 +155,9 @@ const ConditionList: FC = ({ const currentExpression = expressions[index] || {}; const currentOperator = currentExpression.operator; const leftFieldValue = currentExpression.left; - const leftFieldOption = options.find(option => `{{${option.value}}}` === leftFieldValue); + const leftFieldOption = options.find(option => `{{${option.value}}}` === leftFieldValue) + ?? options.flatMap(o => o.children ?? []).find(child => `{{${child.value}}}` === leftFieldValue) + ?? options.flatMap(o => o.children ?? []).flatMap((c: any) => c.children ?? []).find((gc: any) => `{{${gc.value}}}` === leftFieldValue); const leftFieldType = leftFieldOption?.dataType; const hideRightField = currentOperator === 'empty' || currentOperator === 'not_empty' || ['array[object]', 'object'].includes(leftFieldType as string); const operatorList = leftFieldType && ['array[object]', 'object'].includes(leftFieldType) @@ -176,7 +178,7 @@ const ConditionList: FC = ({ className="rb:mb-2!" >
- @@ -216,7 +218,7 @@ const ConditionList: FC = ({ {!hideRightField && ( -
+
{leftFieldType === 'number' ? ( diff --git a/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx b/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx index 5d1138f0..ce37743b 100644 --- a/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx +++ b/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx @@ -155,7 +155,7 @@ const CycleVarsList: FC = ({ ? : ( diff --git a/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx b/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx index 0100707c..24cdc89a 100644 --- a/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx +++ b/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx @@ -62,14 +62,18 @@ const GroupVariableList: FC = ({ */ useEffect(() => { if (!isCanAdd && value[0]) { - const firstVariable = options.find(opt => `{{${opt.value}}}` === value[0]); + const firstVariable = options.find(opt => `{{${opt.value}}}` === value[0]) + ?? options.flatMap(o => o.children ?? []).find(c => `{{${c.value}}}` === value[0]) + ?? options.flatMap(o => o.children ?? []).flatMap((c: any) => c.children ?? []).find((gc: any) => `{{${gc.value}}}` === value[0]); if (firstVariable) { form.setFieldValue(['group_type', 'output'], firstVariable.dataType); } } else if (isCanAdd) { value.forEach((item: any, index: number) => { if (item?.value?.[0]) { - const firstVariable = options.find(opt => `{{${opt.value}}}` === item.value[0]); + const firstVariable = options.find(opt => `{{${opt.value}}}` === item.value[0]) + ?? options.flatMap(o => o.children ?? []).find(c => `{{${c.value}}}` === item.value[0]) + ?? options.flatMap(o => o.children ?? []).flatMap((c: any) => c.children ?? []).find((gc: any) => `{{${gc.value}}}` === item.value[0]); if (firstVariable) { form.setFieldValue(['group_type', index], firstVariable.dataType); } diff --git a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx index e0a27b47..e4b2cc29 100644 --- a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx +++ b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx @@ -85,9 +85,9 @@ const EditableTable: FC = ({ return [ { title: t('workflow.config.name'), - dataIndex: 'name', + dataIndex: 'key', render: (_: any, __: TableRow, index: number) => ( - + { - const value = e.target.value || e.target.value - form.setFieldValue(['body', 'data'], ['form-data', 'x-www-form-urlencoded'].includes(value) ? [{}] : undefined) + const handleChangeBodyContentType = () => { + form.setFieldValue(['body', 'data'], undefined) } // Handle error handling method change and update node ports accordingly diff --git a/web/src/views/Workflow/components/Properties/ListOperator/FilterConditions/index.tsx b/web/src/views/Workflow/components/Properties/ListOperator/FilterConditions/index.tsx index 95c2e113..9799dc24 100644 --- a/web/src/views/Workflow/components/Properties/ListOperator/FilterConditions/index.tsx +++ b/web/src/views/Workflow/components/Properties/ListOperator/FilterConditions/index.tsx @@ -56,7 +56,7 @@ const operatorsObj: { [key: string]: SelectProps['options'] } = { ] } -const typeOptions = ['image', 'document', 'video', 'audio'] +export const typeOptions = ['image', 'document', 'video', 'audio'] const FilterConditions: FC = ({ options, @@ -101,24 +101,20 @@ const FilterConditions: FC = ({ align="start" className="rb:mb-2!" > -
+
{variableType === 'array[file]' && - - - - handleKeyFieldChange(index, value)} + className="rb:w-full! select rb:mb-1!" + variant="borderless" + /> + } - + ({ value: vo, label: t(`application.${vo}`) } ))} - variant="borderless" - className="rb:w-full!" + variant="filled" /> : { if (vo.dataType === keyFieldType) return [vo]; const filteredChildren = vo.children?.filter(sub => sub.dataType === keyFieldType); @@ -167,7 +162,7 @@ const FilterConditions: FC = ({
remove(field.name)} >
diff --git a/web/src/views/Workflow/components/Properties/MappingList/index.tsx b/web/src/views/Workflow/components/Properties/MappingList/index.tsx index f46d6114..1f609445 100644 --- a/web/src/views/Workflow/components/Properties/MappingList/index.tsx +++ b/web/src/views/Workflow/components/Properties/MappingList/index.tsx @@ -58,7 +58,7 @@ const MappingList: FC = ({ label, name, options, extra, valueK placeholder={t('common.pleaseSelect')} options={options} size="small" - className="rb:w-51!" + className="rb:flex-1!" />
{ const { t } = useTranslation() const form = Form.useFormInstance() const model_id = Form.useWatch(['model_id'], form) - console.log('ModelConfig', model_id) + const [selectedModel, setSelectedModel] = useState(null) + const [options, setOptions] = useState([]) + + const updateOptions = (options: Model[]) => { + setOptions(options) + } + + useEffect(() => { + if (model_id && options) { + const model = options.find(item => item.id === model_id) + setSelectedModel(model || null) + form.setFieldValue('json_output', false) + } else { + setSelectedModel(null) + } + }, [model_id, options]) return ( <> @@ -25,6 +47,7 @@ const ModelConfig: FC = () => { params={{ type: 'llm,chat' }} className="rb:w-full!" size="small" + updateOptions={updateOptions} /> {model_id && ( @@ -52,7 +75,7 @@ const ModelConfig: FC = () => { { className="rb:-mt-2!" /> + )} diff --git a/web/src/views/Workflow/components/Properties/VariableSelect.tsx b/web/src/views/Workflow/components/Properties/VariableSelect.tsx index b28d7b4f..b749d3b2 100644 --- a/web/src/views/Workflow/components/Properties/VariableSelect.tsx +++ b/web/src/views/Workflow/components/Properties/VariableSelect.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:40:13 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-08 10:48:21 + * @Last Modified time: 2026-04-16 13:57:30 */ import { useState, useRef, useEffect, useLayoutEffect, type FC } from 'react' import { createPortal } from 'react-dom' @@ -40,15 +40,34 @@ const VariableSelect: FC = ({ const { t } = useTranslation(); const [open, setOpen] = useState(false); const [search, setSearch] = useState(''); - const [expandedParent, setExpandedParent] = useState(null); + const [expandedParentKey, setExpandedParentKey] = useState(null); + const [activeIndex, setActiveIndex] = useState(-1); + const [activePanel, setActivePanel] = useState<'main' | 'child'>('main'); + const [childActiveIndex, setChildActiveIndex] = useState(-1); const [dropdownPos, setDropdownPos] = useState({ top: 0, left: 0, width: 0 }); const [childPanelPos, setChildPanelPos] = useState({ top: 0, right: 0 }); const containerRef = useRef(null); const dropdownRef = useRef(null); const itemRefs = useRef>(new Map()); + const childItemRefs = useRef>(new Map()); + const activeKeyRef = useRef(null); const CHILD_PANEL_HEIGHT = 280; // max-h-60 (240) + header (~40) + const calcChildPos = (key: string) => { + const el = itemRefs.current.get(key); + if (!el) return; + const rect = el.getBoundingClientRect(); + const dropdownEl = dropdownRef.current; + if (!dropdownEl) return; + const dropdownRect = dropdownEl.getBoundingClientRect(); + const dropdownBottom = dropdownRect.bottom; + const actualChildHeight = Math.min(CHILD_PANEL_HEIGHT, dropdownRect.height); + // Bottom-align child panel with main panel + const top = Math.max(10, dropdownBottom - actualChildHeight); + setChildPanelPos({ top, right: window.innerWidth - rect.left + 8 }); + }; + // Calculate dropdown position (runs synchronously after DOM paint to avoid flicker) useLayoutEffect(() => { if (!open || !containerRef.current) return; @@ -69,7 +88,9 @@ const VariableSelect: FC = ({ ? triggerRect.bottom + MARGIN : Math.max(MARGIN, triggerRect.top - dropdownHeight - MARGIN); setDropdownPos({ top, left, width }); - }, [open, search, Array.isArray(value) ? value.length : 0]); + // Re-calculate child panel position if expanded + if (expandedParentKey) calcChildPos(expandedParentKey); + }, [open, search, Array.isArray(value) ? value.length : 0, options.length, expandedParentKey]); const filteredOptions = filterBooleanType ? options.filter(o => o.dataType !== 'boolean') @@ -84,6 +105,10 @@ const VariableSelect: FC = ({ ? filteredOptions.find(o => o.children?.some(c => `{{${c.value}}}` === value)) : undefined; + const expandedParent = expandedParentKey + ? filteredOptions.find(o => o.key === expandedParentKey) ?? null + : null; + const groupedSuggestions = filteredOptions.reduce((groups: Record, s) => { const nodeId = s.nodeData.id as string; if (!groups[nodeId]) groups[nodeId] = []; @@ -103,6 +128,12 @@ const VariableSelect: FC = ({ }, {}) : groupedSuggestions; + useEffect(() => { + if (!expandedParentKey) return; + calcChildPos(expandedParentKey); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [dropdownPos, expandedParentKey]); + useEffect(() => { if (!open) return; const updatePos = () => { @@ -139,7 +170,7 @@ const VariableSelect: FC = ({ ) { setOpen(false); setSearch(''); - setExpandedParent(null); + setExpandedParentKey(null); setChildPanelPos({ top: 0, right: 0 }); } }; @@ -147,6 +178,87 @@ const VariableSelect: FC = ({ return () => document.removeEventListener('mousedown', handler); }, [open]); + // Flat list of all visible selectable items (main panel only, no children expanded inline) + const flatItems = Object.values(filteredGroups).flat(); + + useEffect(() => { + setActiveIndex(-1); + setActivePanel('main'); + setChildActiveIndex(-1); + }, [open, search]); + + useEffect(() => { + if (activeIndex < 0 || activeIndex >= flatItems.length) { + setExpandedParentKey(null); + return; + } + const s = flatItems[activeIndex]; + activeKeyRef.current = s.key; + itemRefs.current.get(s.key)?.scrollIntoView({ block: 'nearest' }); + if (s.children?.length) { + calcChildPos(s.key); + setExpandedParentKey(s.key); + } else { + setExpandedParentKey(null); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [activeIndex]); + + useEffect(() => { + if (!expandedParent?.children?.length || childActiveIndex < 0) return; + const child = expandedParent.children[childActiveIndex]; + if (child) childItemRefs.current.get(child.key)?.scrollIntoView({ block: 'nearest' }); + }, [childActiveIndex, expandedParent]); + + useEffect(() => { + if (!open) return; + const handler = (e: KeyboardEvent) => { + const children = expandedParent?.children ?? []; + if (activePanel === 'child') { + if (e.key === 'ArrowDown') { + e.preventDefault(); + setChildActiveIndex(i => Math.min(i + 1, children.length - 1)); + } else if (e.key === 'ArrowUp') { + e.preventDefault(); + setChildActiveIndex(i => Math.max(i - 1, 0)); + } else if (e.key === 'ArrowRight') { + e.preventDefault(); + setActivePanel('main'); + setChildActiveIndex(-1); + } else if (e.key === 'Enter' && childActiveIndex >= 0 && childActiveIndex < children.length) { + e.preventDefault(); + const child = children[childActiveIndex]; + if (!child.disabled) handleSelect(child); + } else if (e.key === 'Escape') { + setOpen(false); + } + } else { + if (e.key === 'ArrowDown') { + e.preventDefault(); + setActiveIndex(i => Math.min(i + 1, flatItems.length - 1)); + } else if (e.key === 'ArrowUp') { + e.preventDefault(); + setActiveIndex(i => Math.max(i - 1, 0)); + } else if (e.key === 'ArrowLeft') { + e.preventDefault(); + if (expandedParent?.children?.length) { + setActivePanel('child'); + setChildActiveIndex(0); + } + } else if (e.key === 'Enter' && activeIndex >= 0 && activeIndex < flatItems.length) { + e.preventDefault(); + const s = flatItems[activeIndex]; + if (!s.disabled) handleSelect(s); + } else if (e.key === 'Escape') { + setOpen(false); + } + } + }; + document.addEventListener('keydown', handler); + return () => document.removeEventListener('keydown', handler); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [open, activeIndex, activePanel, childActiveIndex, flatItems, expandedParent]); + const handleSelect = (suggestion: Suggestion) => { if (multiple) { const key = `{{${suggestion.value}}}`; @@ -159,7 +271,7 @@ const VariableSelect: FC = ({ onChange?.(`{{${suggestion.value}}}`, suggestion); setOpen(false); setSearch(''); - setExpandedParent(null); + setExpandedParentKey(null); } }; @@ -167,19 +279,6 @@ const VariableSelect: FC = ({ e.stopPropagation(); onChange?.(multiple ? [] : '', multiple ? [] : undefined); }; - - const updateChildPos = (key: string) => { - const el = itemRefs.current.get(key); - if (el) { - const rect = el.getBoundingClientRect(); - const spaceBelow = window.innerHeight - rect.top - 10; - const top = spaceBelow >= CHILD_PANEL_HEIGHT - ? rect.top - : Math.max(10, window.innerHeight - CHILD_PANEL_HEIGHT - 10); - setChildPanelPos({ top, right: window.innerWidth - rect.left + 8 }); - } - }; - const sep = /; const isConversation = (parentOfSelected ?? selectedSuggestion)?.group === 'CONVERSATION' || (selectedSuggestion ? filteredOptions.some(o => o.group === 'CONVERSATION' && o.children?.some(c => `{{${c.value}}}` === value)) : false); @@ -190,20 +289,30 @@ const VariableSelect: FC = ({ {/* Trigger */}
setOpen(o => !o)} > {multiple ? ( selectedValues.length > 0 ? ( - + {selectedValues.map(v => { const s = suggestionMap.get(v); if (!s) return null; @@ -214,11 +323,11 @@ const VariableSelect: FC = ({ return ( - {!isConv && nd?.icon &&
} + {!isConv && nd?.icon &&
} {!isConv && nd?.name && {nd.name}{sep}} - + {parent ? <>{parent.label}{sep}{s.label} : s.label} = ({ ); })} - + ) : ( - {placeholder} + {placeholder} ) ) : selectedSuggestion ? (
- - {!isConversation && nodeData?.icon &&
} - {!isConversation && nodeData?.name && {nodeData.name}} - {!isConversation && nodeData?.name && {sep}} - + + {!isConversation && nodeData?.icon &&
} + {!isConversation && nodeData?.name && {nodeData.name}} + {!isConversation && nodeData?.name && {sep}} + {parentOfSelected ? <>{parentOfSelected.label}{sep}{selectedSuggestion.label} : selectedSuggestion.label}
) : ( - {placeholder} + {placeholder} )} {allowClear && ( @@ -266,18 +377,19 @@ const VariableSelect: FC = ({ {open && createPortal(
-
- {Object.entries(filteredGroups).map(([nodeId, suggestions]) => { +
+ {Object.entries(filteredGroups).map(([nodeId, suggestions], index) => { const nd = suggestions[0].nodeData; return ( -
- - {nd.icon &&
} +
+
{nd.name} - +
{suggestions.map(s => { const isSelected = multiple ? selectedValues.includes(`{{${s.value}}}`) @@ -288,11 +400,9 @@ const VariableSelect: FC = ({ { if (el) itemRefs.current.set(s.key, el); }} - className={clsx("rb:pl-6! rb:pr-3! rb:py-1.25! rb:rounded-lg!", { - 'rb:bg-[#e6f4ff]': isSelected || isExpanded, - 'rb:bg-white rb:hover:bg-[#F6F6F6]!': !(isSelected || isExpanded), - 'rb:opacity-60': s.disabled, - 'rb:cursor-not-allowed': s.disabled, + className={clsx("rb:px-2! rb:py-0.75! rb:rounded-sm rb:leading-4.5 rb:text-[#5B6167] rb:hover:bg-[#F6F6F6]", { + 'rb:bg-[#F6F6F6]': isSelected || isExpanded || flatItems.indexOf(s) === activeIndex, + 'rb:cursor-not-allowed rb:opacity-65': s.disabled, 'rb:cursor-pointer': !s.disabled, })} align="center" @@ -300,30 +410,29 @@ const VariableSelect: FC = ({ onClick={() => { if (s.disabled) return; if (hasChildren) { - updateChildPos(s.key); - setExpandedParent(prev => prev?.key === s.key ? null : s); + calcChildPos(s.key); + setExpandedParentKey(prev => prev === s.key ? null : s.key); } handleSelect(s); }} onMouseEnter={() => { if (hasChildren) { - updateChildPos(s.key); - setExpandedParent(s); + calcChildPos(s.key); + setExpandedParentKey(s.key); } else { - setExpandedParent(null); + setExpandedParentKey(null); } }} > - +
{multiple && ( - + )} - {`{x}`} - {s.label} - - - {s.dataType && {s.dataType}} + {`{x}`} {s.label} +
+ + {s.dataType && {s.dataType}} {hasChildren &&
}
@@ -334,7 +443,7 @@ const VariableSelect: FC = ({ })} {Object.keys(filteredGroups).length === 0 && (
- {t('workflow.variableSelect.empty', '暂无变量')} + {t('workflow.variableSelect.empty')}
)}
@@ -346,51 +455,43 @@ const VariableSelect: FC = ({ {open && expandedParent?.children?.length && createPortal(
setExpandedParent(expandedParent)} + onMouseEnter={() => setExpandedParentKey(expandedParentKey)} > -
!expandedParent.disabled && handleSelect(expandedParent)} - > +
- - {expandedParent.nodeData.name}.{expandedParent.label} - + {expandedParent.nodeData.name}.{expandedParent.label} {expandedParent.dataType}
- {expandedParent.children.map(child => { + {expandedParent.children.map((child, ci) => { const isSelected = multiple ? selectedValues.includes(`{{${child.value}}}`) : `{{${child.value}}}` === value; - const hasGrandChildren = !!child.children?.length; + const isChildActive = activePanel === 'child' && ci === childActiveIndex; return ( { if (el) childItemRefs.current.set(child.key, el); }} + className={clsx("rb:px-2! rb:py-0.75! rb:rounded-sm rb:leading-4.5 rb:text-[#5B6167] rb:hover:bg-[#F6F6F6]", { + 'rb:bg-[#F6F6F6]': isSelected || isChildActive, + 'rb:cursor-not-allowed rb:opacity-65': child.disabled, + 'rb:cursor-pointer': !child.disabled, })} align="center" justify="space-between" - style={{ - cursor: child.disabled ? 'not-allowed' : 'pointer', - opacity: child.disabled ? 0.5 : 1, - }} onClick={() => !child.disabled && handleSelect(child)} > - + {multiple && ( )} - {child.label} - - - {child.dataType && {child.dataType}} - {hasGrandChildren && } + {child.label} + + {child.dataType && {child.dataType}} + ); })} diff --git a/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts b/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts index 3c4ea6f7..14dcced2 100644 --- a/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts +++ b/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-01-19 17:00:26 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-08 10:12:27 + * @Last Modified time: 2026-04-13 10:44:17 */ /** * useVariableList Hook @@ -414,7 +414,7 @@ export const useVariableList = ( const pd = parentLoop.getData(); const pid = pd.id; if (pd.type === 'loop') { - (pd.cycle_vars || []).forEach((cv: any) => addVariable(list, keys, `${pid}_cycle_${cv.name}`, cv.name, cv.type || 'String', `${pid}.${cv.name}`, pd)); + (pd.cycle_vars || []).forEach((cv: any) => addVariable(list, keys, `${pid}_cycle_${cv.name}`, cv.name, cv.type || 'string', `${pid}.${cv.name}`, pd)); } else if (pd.type === 'iteration' && pd.config.input.defaultValue) { let itemType = 'object'; const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue); diff --git a/web/src/views/Workflow/components/Properties/index.tsx b/web/src/views/Workflow/components/Properties/index.tsx index b5bc2d2e..f826edd9 100644 --- a/web/src/views/Workflow/components/Properties/index.tsx +++ b/web/src/views/Workflow/components/Properties/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:39:59 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-10 17:24:19 + * @Last Modified time: 2026-04-13 10:44:19 */ import { type FC, useEffect, useState, useMemo } from "react"; import clsx from 'clsx' @@ -266,7 +266,7 @@ const Properties: FC = ({ key, label: cycleVar.name, type: 'variable', - dataType: cycleVar.type || 'String', + dataType: cycleVar.type || 'string', value: `${parentNodeId}.${cycleVar.name}`, nodeData: parentData, }); @@ -643,7 +643,7 @@ const Properties: FC = ({ key: contextKey, label: 'context', type: 'variable', - dataType: 'String', + dataType: 'string', value: `context`, nodeData: selectedNode.getData(), isContext: true, @@ -791,7 +791,7 @@ const Properties: FC = ({ key: `${selectedNode.id}_cycle_${cycleVar.name}`, label: cycleVar.name, type: 'variable', - dataType: cycleVar.type || 'String', + dataType: cycleVar.type || 'string', value: `${selectedNode.getData().id}.${cycleVar.name}`, nodeData: selectedNode.getData(), })); diff --git a/web/src/views/Workflow/components/Properties/properties.module.css b/web/src/views/Workflow/components/Properties/properties.module.css index 66da00bf..58b21591 100644 --- a/web/src/views/Workflow/components/Properties/properties.module.css +++ b/web/src/views/Workflow/components/Properties/properties.module.css @@ -23,6 +23,11 @@ } .properties :global(.select.ant-select-single.ant-select-sm.ant-select-borderless) { height: 28px; + border: 1px solid #F6F6F6; + border-radius: 8px; +} +.properties :global(.select.ant-select-single.ant-select-sm.ant-select-borderless.ant-select-focused) { + border: 1px solid #171719; } .properties :global(.ant-table-wrapper .ant-table-thead>tr>th), .properties :global(.ant-table-wrapper .ant-table-thead>tr>td), @@ -157,4 +162,7 @@ padding-inline-start: 0px; border-radius: 4px; margin-block: 0px; +} +.properties :global(.ant-input-number-affix-wrapper) { + font-size: 12px; } \ No newline at end of file diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index eed77e2c..cae20180 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:06:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-07 19:56:56 + * @Last Modified time: 2026-04-16 17:52:30 */ import LoopNode from './components/Nodes/LoopNode'; import NormalNode from './components/Nodes/NormalNode'; @@ -101,6 +101,10 @@ export const nodeLibrary: NodeLibrary[] = [ step: 1, defaultValue: 2000 }, + json_output: { + type: 'define', + defaultValue: false + }, context: { type: 'variableList', placeholder: 'workflow.config.llm.contextPlaceholder' diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index f385acf3..e82ad580 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -2,9 +2,10 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:17:48 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-07 23:17:50 + * @Last Modified time: 2026-04-20 16:00:26 */ -import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, type Edge } from '@antv/x6'; +import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6'; +import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type'; import { register } from '@antv/x6-react-shape'; import type { PortMetadata } from '@antv/x6/lib/model/port'; import { App } from 'antd'; @@ -18,6 +19,7 @@ import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types'; import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant'; import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types'; import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils'; +import { useWorkflowStore } from '@/store/workflow'; /** * Props for useWorkflowGraph hook @@ -63,6 +65,14 @@ export interface UseWorkflowGraphReturn { copyEvent: () => boolean | void; /** Handler for paste keyboard event */ parseEvent: () => boolean | void; + /** Whether undo is available */ + canUndo: boolean; + /** Whether redo is available */ + canRedo: boolean; + /** Undo last action */ + undo: () => void; + /** Redo last undone action */ + redo: () => void; /** Function to save workflow configuration */ handleSave: (flag?: boolean) => Promise; /** Chat variables for workflow */ @@ -94,6 +104,8 @@ export const useWorkflowGraph = ({ const { message } = App.useApp(); const { t } = useTranslation() const { user } = useUser(); + const { chatHistoryMap } = useWorkflowStore() + const chatHistory = Object.values(chatHistoryMap).at(-1) ?? [] // Refs const graphRef = useRef(); @@ -105,12 +117,15 @@ export const useWorkflowGraph = ({ const [config, setConfig] = useState(null); const [chatVariables, setChatVariables] = useState([]) const featuresRef = useRef(undefined) + const [canUndo, setCanUndo] = useState(false) + const [canRedo, setCanRedo] = useState(false) useEffect(() => { if (!graphRef.current) return graphRef.current.getNodes().forEach(node => { const data = node.getData() if (data?.type === 'if-else' || data?.type === 'question-classifier') { + console.log('chatVariables', chatVariables) node.setData({ ...data, chatVariables }, { silent: true }) } }) @@ -203,7 +218,7 @@ export const useWorkflowGraph = ({ ? Object.entries(group_variables as Record).map(([key, value]) => ({ key, value })) : group_variables } else if (type === 'http-request' && (key === 'headers' || key === 'params') && config[key] && typeof config[key] === 'object' && !Array.isArray(config[key]) && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { - nodeLibraryConfig.config[key].defaultValue = Object.entries(config[key]).map(([name, value]) => ({ name, value })) + nodeLibraryConfig.config[key].defaultValue = Object.entries(config[key]).map(([key, value]) => ({ key, value })) } else if (type === 'code' && key === 'code' && config[key] && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { try { nodeLibraryConfig.config[key].defaultValue = decodeURIComponent(atob(config[key] as string)) @@ -469,6 +484,8 @@ export const useWorkflowGraph = ({ graphRef.current.getNodes().forEach(node => { if (node.getData()?.cycle) node.toFront(); }); + graphRef.current.enableHistory() + graphRef.current.cleanHistory() } }, 200) } @@ -504,6 +521,22 @@ export const useWorkflowGraph = ({ global: true, }), ); + graphRef.current.use( + new History({ + enabled: false, + beforeAddCommand(_event, args: any) { + const event = args?.key ? `cell:change:${args.key}` : _event; + if (event.startsWith('cell:change:') && + event !== 'cell:change:position' && + event !== 'cell:change:source' && + event !== 'cell:change:target') return false; + }, + }), + ); + graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => { + setCanUndo(graphRef.current?.canUndo() ?? false) + setCanRedo(graphRef.current?.canRedo() ?? false) + }) }; // 显示/隐藏连接桩 // const showPorts = (show: boolean) => { @@ -1022,24 +1055,39 @@ export const useWorkflowGraph = ({ graphRef.current.on('node:removed', blankClick) // When edge connected, bring connected nodes' ports to front - graphRef.current.on('edge:connected', ({ isNew }) => { - // Bring edge to front first, then bring child nodes above edges - // Parent (loop/iteration) nodes stay behind to avoid covering edges - // Reset any port hover state left from dragging + graphRef.current.on('edge:connected', ({ isNew, edge }) => { if (isNew) { - graphRef.current?.getNodes().forEach(node => { - if (!node.getData()?.cycle) node.toFront(); - }); - graphRef.current?.getEdges().forEach(edge => { - const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId()); - const targetCell = graphRef.current?.getCellById(edge.getTargetCellId()); - if (sourceCell?.getData()?.cycle || targetCell?.getData()?.cycle) { - edge.toFront(); - } - }); - graphRef.current?.getNodes().forEach(node => { - if (node.getData()?.cycle) node.toFront(); - }); + const sourceCellId = edge.getSourceCellId() + const targetCellId = edge.getTargetCellId() + const sourceCell = graphRef.current?.getCellById(sourceCellId); + const targetCell = graphRef.current?.getCellById(targetCellId); + + sourceCell?.toFront(); + targetCell?.toFront() + if (['loop', 'iteration'].includes(sourceCell?.getData()?.type)) { + graphRef.current?.getEdges().forEach(edge => { + const edgeSourceCell = graphRef.current?.getCellById(edge.getSourceCellId()); + const edgeTargetCell = graphRef.current?.getCellById(edge.getTargetCellId()); + if (edgeSourceCell?.getData()?.cycle === sourceCellId || edgeTargetCell?.getData()?.cycle === sourceCellId) { + edge.toFront(); + } + }); + graphRef.current?.getNodes().forEach(node => { + if (node.getData()?.cycle === sourceCellId) node.toFront(); + }); + } + if (['loop', 'iteration'].includes(targetCell?.getData()?.type)) { + graphRef.current?.getEdges().forEach(edge => { + const edgeSourceCell = graphRef.current?.getCellById(edge.getSourceCellId()); + const edgeTargetCell = graphRef.current?.getCellById(edge.getTargetCellId()); + if (edgeSourceCell?.getData()?.cycle === targetCellId || edgeTargetCell?.getData()?.cycle === targetCellId) { + edge.toFront(); + } + }); + graphRef.current?.getNodes().forEach(node => { + if (node.getData()?.cycle === targetCellId) node.toFront(); + }); + } } }); @@ -1077,6 +1125,9 @@ export const useWorkflowGraph = ({ graphRef.current.bindKey(['ctrl+v', 'cmd+v'], parseEvent); // Delete selected nodes and edges graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent); + // Undo / Redo + graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; }); + graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; }); }; @@ -1184,9 +1235,6 @@ export const useWorkflowGraph = ({ }) || []; const edges = graphRef.current?.getEdges() || [] - - console.log('config', config) - const params = { ...config, features: featuresRef.current, @@ -1243,9 +1291,17 @@ export const useWorkflowGraph = ({ itemConfig[key] = {} if (value.length > 0) { value.forEach((vo: any) => { - itemConfig[key][vo.name] = vo.value + itemConfig[key][vo.key] = vo.value }) } + } else if (data.type === 'http-request' && key === 'body' && data.config[key] && 'defaultValue' in data.config[key]) { + const value = data.config[key].defaultValue + itemConfig[key] = value + if (value.content_type === 'json' && value.data && value.data !== '') { + itemConfig[key].data = value.data.replace(/\u00a0/g, ' ') + } else { + itemConfig[key].data = value.data + } } else if (data.config[key] && 'defaultValue' in data.config[key] && key !== 'knowledge_retrieval') { itemConfig[key] = data.config[key].defaultValue } else if (key === 'knowledge_retrieval' && data.config[key] && 'defaultValue' in data.config[key]) { @@ -1390,6 +1446,9 @@ export const useWorkflowGraph = ({ return userVars } + const undo = () => graphRef.current?.undo() + const redo = () => graphRef.current?.redo() + const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => { const { statement = '' } = value?.opening_statement || {} featuresRef.current = value @@ -1425,6 +1484,31 @@ export const useWorkflowGraph = ({ } } } + useEffect(() => { + if (!graphRef.current) return; + const nodes = graphRef.current.getNodes(); + + const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length); + // Reset all node execution status first + nodes.forEach(node => { + const data = node.getData(); + if (typeof data.status === 'string') { + node.setData({ ...data, executionStatus: undefined }); + } + }); + if (!lastWithSub?.subContent) return; + // Build a nodeId -> status map first + const statusMap: Record = {}; + lastWithSub.subContent.forEach(sub => { + if (typeof sub.status === 'string') { + statusMap[sub.node_id] = sub.status; + const node = nodes.find(n => n.getData()?.id === sub.node_id); + if (node) { + node.setData({ ...node.getData(), executionStatus: sub.status }); + } + } + }); + }, [chatHistory, graphRef.current]); return { config, @@ -1449,5 +1533,9 @@ export const useWorkflowGraph = ({ handleSaveFeaturesConfig, features: featuresRef.current, getStartNodeVariables, + canUndo, + canRedo, + undo, + redo, }; }; diff --git a/web/src/views/Workflow/index.tsx b/web/src/views/Workflow/index.tsx index 26d7420c..f98cf308 100644 --- a/web/src/views/Workflow/index.tsx +++ b/web/src/views/Workflow/index.tsx @@ -39,6 +39,10 @@ const Workflow = forwardRef { @@ -96,6 +100,10 @@ const Workflow = forwardRef
diff --git a/web/src/views/Workflow/utils.ts b/web/src/views/Workflow/utils.ts index 67a913f3..bd81b6eb 100644 --- a/web/src/views/Workflow/utils.ts +++ b/web/src/views/Workflow/utils.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-03-24 15:07:49 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 15:07:49 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-04-17 20:40:47 */ import { portItemArgsY, conditionNodePortItemArgsY, conditionNodeHeight } from './constant' @@ -22,11 +22,31 @@ import { portItemArgsY, conditionNodePortItemArgsY, conditionNodeHeight } from ' * @param cases - Array of case objects, each containing an `expressions` array. * @returns The total pixel height for the condition node. */ +export const isSubExprSet = (sub: any) => { + if (!sub?.key) return false; + if (['not_empty', 'empty'].includes(sub?.operator)) return true; + return !!sub.value || typeof sub.value === 'boolean' || typeof sub.value === 'number'; +}; + +const getEffectiveExprCount = (expr: any): number => { + const subs = expr?.sub_variable_condition?.conditions; + if (subs?.length && subs.every(isSubExprSet)) return 1 + subs.length; + if (subs?.length > 0) { + return 2 + } + return 1; +}; + export const calcConditionNodeTotalHeight = (cases: any[]) => { - // Total number of expressions across all cases - const exprCount = cases.reduce((acc: number, c: any) => acc + (c?.expressions?.length || 0), 0); - // Sum of expression counts only for cases that have more than one expression - const hasMultiExprCount = cases.reduce((acc: number, c: any) => acc + (c?.expressions?.length > 1 ? c?.expressions?.length : 0), 0); + // Total number of effective expression rows (sub_variable_condition expand height when all set) + const exprCount = cases.reduce((acc: number, c: any) => + acc + (c?.expressions?.reduce((s: number, e: any) => s + getEffectiveExprCount(e), 0) || 0), 0); + // Sum of effective expression counts only for cases that have more than one expression + const hasMultiExprCount = cases.reduce((acc: number, c: any) => { + if (!c?.expressions?.length || c.expressions.length <= 1) return acc; + const effectiveCount = c.expressions.reduce((s: number, e: any) => s + getEffectiveExprCount(e), 0); + return acc + effectiveCount; + }, 0); return conditionNodeHeight + (cases.length - 1) * 26 + exprCount * 20 + hasMultiExprCount * 3; }; @@ -68,17 +88,44 @@ export const getConditionNodeCasePortY = (cases: any[], caseIndex: number) => { let singleExprCount = 0; let multiExprCount = 0; let extraExprs = 0; + let portItemArgsYNum = 0; for (let i = 0; i < caseIndex; i++) { + const notHasSub = cases[i]?.expressions?.filter((e: any) => !e?.sub_variable_condition?.conditions || e?.sub_variable_condition?.conditions.length <1).length const n = cases[i]?.expressions?.length || 0; - y += portItemArgsY * (n + 1); - if (n === 1) singleExprCount++; - else if (n >= 2) { + let casePortItemArgsYNum = n + 1; + // Add extra y for expressions with all sub_variable_condition set + cases[i]?.expressions?.forEach((expr: any) => { + const subs = expr?.sub_variable_condition?.conditions; + if (subs?.length && subs.every(isSubExprSet)) { + casePortItemArgsYNum += subs.length; + } else if (subs?.length) { + casePortItemArgsYNum += 1 + } + }); + portItemArgsYNum += casePortItemArgsYNum; + if (n === 1 && !cases[i]?.expressions?.some((e: any) => e?.sub_variable_condition?.conditions?.length > 0)) { + singleExprCount++ + } else if (n >= 2 || cases[i]?.expressions?.some((e: any) => e?.sub_variable_condition?.conditions?.length > 0)) { multiExprCount++; - if (n > 2) extraExprs += n - 2; + cases[i]?.expressions?.forEach((e: any) => { + const subs = e?.sub_variable_condition?.conditions; + if (subs?.length && subs.every(isSubExprSet) && subs.length > 1) { + extraExprs += subs.length + 2; + } + }); + + console.log('extraExprs notHasSub', notHasSub) + if (notHasSub > 3) { + extraExprs += n - 2 + notHasSub/4; + } else { + extraExprs += n - 2 + notHasSub/4 + } } } + console.log('singleExprCount', singleExprCount, 'multiExprCount', multiExprCount, 'extraExprs', extraExprs) + y += portItemArgsY * portItemArgsYNum // Correction for single-expression cases (slightly shorter rendered height) if (singleExprCount > 0) y -= singleExprCount * 7 + 2; // Correction for multi-expression cases (compact logical operator row) diff --git a/web/vite.config.ts b/web/vite.config.ts index 8cc1fa3b..4a1a0b34 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -3,6 +3,7 @@ import react from '@vitejs/plugin-react' import { resolve } from 'path' import AutoImport from 'unplugin-auto-import/vite' import tailwindcss from '@tailwindcss/vite' +import svgr from 'vite-plugin-svgr'; // https://vite.dev/config/ export default defineConfig({ @@ -32,6 +33,7 @@ export default defineConfig({ imports: ['react', 'react-router-dom'], dts: 'public/auto-imports.d.ts', }), + svgr({ svgrOptions: { icon: true } }), ], css: { modules: {