diff --git a/.github/workflows/release-notify-wechat.yml b/.github/workflows/release-notify-wechat.yml new file mode 100644 index 00000000..935d84d5 --- /dev/null +++ b/.github/workflows/release-notify-wechat.yml @@ -0,0 +1,164 @@ +name: Release Notify Workflow + +on: + pull_request: + types: [closed] + +jobs: + notify: + if: > + github.event.pull_request.merged == true && + startsWith(github.event.pull_request.base.ref, 'release') + runs-on: ubuntu-latest + + steps: + # 防止 GitHub HEAD 未同步 + - run: sleep 3 + + # 1️⃣ 获取分支 HEAD + - name: Get HEAD + id: head + run: | + HEAD_SHA=$(curl -s \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \ + | jq -r '.object.sha') + echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT + + # 2️⃣ 判断是否最终PR + - name: Check Latest + id: check + run: | + if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then + echo "ok=true" >> $GITHUB_OUTPUT + else + echo "ok=false" >> $GITHUB_OUTPUT + fi + + # 3️⃣ 尝试从 PR body 提取 Sourcery 摘要 + - name: Extract Sourcery Summary + if: steps.check.outputs.ok == 'true' + id: sourcery + env: + PR_BODY: ${{ github.event.pull_request.body }} + run: | + python3 << 'PYEOF' + import os, re + + body = os.environ.get("PR_BODY", "") or "" + match = re.search( + r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)", + body, + re.DOTALL + ) + + if match: + summary = match.group(1).strip() + found = "true" + else: + summary = "" + found = "false" + + with open("sourcery_summary.txt", "w", encoding="utf-8") as f: + f.write(summary) + + with open(os.environ["GITHUB_OUTPUT"], "a") as gh: + gh.write(f"found={found}\n") + gh.write("summary< commits.txt + + - name: AI Summary (Qwen Fallback) + if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false' + id: qwen + env: + DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }} + run: | + python3 << 'PYEOF' + 
import json, os, urllib.request + + with open("commits.txt", "r") as f: + commits = f.read().strip() + + prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits + payload = {"model": "qwen-plus", "input": {"prompt": prompt}} + data = json.dumps(payload, ensure_ascii=False).encode("utf-8") + + req = urllib.request.Request( + "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", + data=data, + headers={ + "Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"], + "Content-Type": "application/json" + } + ) + resp = urllib.request.urlopen(req) + result = json.loads(resp.read().decode()) + summary = result.get("output", {}).get("text", "AI 摘要生成失败") + + with open(os.environ["GITHUB_OUTPUT"], "a") as gh: + gh.write("summary< � **分支**: " + os.environ["BRANCH"] + "\n" + "> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n" + "> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n" + "> 🔢 **PR编号**: #" + pr_number + "\n" + "> 🔖 **Commit**: " + short_sha + "\n\n" + "### 🧠 " + label + "\n" + + summary + "\n\n" + "---\n" + "🔗 [查看PR详情](" + os.environ["PR_URL"] + ")" + ) + payload = {"msgtype": "markdown", "markdown": {"content": content}} + data = json.dumps(payload, ensure_ascii=False).encode("utf-8") + req = urllib.request.Request( + os.environ["WECHAT_WEBHOOK"], + data=data, + headers={"Content-Type": "application/json"} + ) + resp = urllib.request.urlopen(req) + print(resp.read().decode()) + PYEOF diff --git a/.github/workflows/sync-to-gitee.yml b/.github/workflows/sync-to-gitee.yml new file mode 100644 index 00000000..71ddf22a --- /dev/null +++ b/.github/workflows/sync-to-gitee.yml @@ -0,0 +1,36 @@ +name: Sync to Gitee + +on: + push: + branches: + - main # Production + - develop # Integration + - 'release/*' # Release preparation + - 'hotfix/*' # Urgent fixes + tags: + - '*' # All version tags (v1.0.0, etc.) 
+ +jobs: + sync: + runs-on: ubuntu-latest + + steps: + - name: Checkout Source Code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Sync to Gitee + run: | + GITEE_URL="https://${{ secrets.GITEE_USERNAME }}:${{ secrets.GITEE_TOKEN }}@gitee.com/hangzhou-hongxiong-intelligent_1/MemoryBear.git" + git remote add gitee "$GITEE_URL" + + # 遍历并推送所有分支 + for branch in $(git branch -r | grep -v HEAD | sed 's/origin\///'); do + echo "Syncing branch: $branch" + git push -f gitee "origin/$branch:refs/heads/$branch" + done + + # 推送所有标签 + echo "Syncing tags..." + git push gitee --tags --force diff --git a/.gitignore b/.gitignore index ae3261f0..a1896da7 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ examples/ .kiro .vscode .idea +.claude # Temporary outputs .DS_Store @@ -26,6 +27,7 @@ time.log celerybeat-schedule.db search_results.json redbear-mem-metrics/ +redbear-mem-benchmark/ pitch-deck/ api/migrations/versions diff --git a/README.md b/README.md index 95d8d737..cd5be68b 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ # MemoryBear empowers AI with human-like memory capabilities +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) +[![Python](https://img.shields.io/badge/Python-3.12+-green?logo=python&logoColor=white)](https://www.python.org/) +[![Gitee Sync](https://img.shields.io/github/actions/workflow/status/SuanmoSuanyangTechnology/MemoryBear/sync-to-gitee.yml?label=Gitee%20Sync&logo=gitee&logoColor=white)](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml) + [中文](./README_CN.md) | English ### [Installation Guide](#memorybear-installation-guide) diff --git a/README_CN.md b/README_CN.md index 1472acac..31ea718f 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,6 +2,10 @@ # MemoryBear 让AI拥有如同人类一样的记忆 +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) 
+[![Python](https://img.shields.io/badge/Python-3.12+-green?logo=python&logoColor=white)](https://www.python.org/) +[![Gitee Sync](https://img.shields.io/github/actions/workflow/status/SuanmoSuanyangTechnology/MemoryBear/sync-to-gitee.yml?label=Gitee%20Sync&logo=gitee&logoColor=white)](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml) + 中文 | [English](./README.md) ### [安装教程](#memorybear安装教程) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 23fd82ed..e44001d9 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -111,11 +111,17 @@ celery_app.conf.update( # Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题) 'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'}, + # Metadata extraction → memory_tasks queue + 'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, - 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, + # GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析) + 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'}, + 'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'}, + # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, diff --git a/api/app/config/default_free_plan.py b/api/app/config/default_free_plan.py new file mode 100644 index 00000000..23a3a10e --- /dev/null +++ b/api/app/config/default_free_plan.py @@ -0,0 +1,30 @@ +""" +社区版默认免费套餐配置 +当无法从 SaaS 版获取 premium 模块时,使用此配置作为兜底 +""" + +DEFAULT_FREE_PLAN = { + "name": "记忆体验版", + "category": "saas_personal", + "tier_level": 0, + "version": "1.0", + "status": True, + "price": 0, + 
"billing_cycle": "permanent_free", + "core_value": "感受永久记忆", + "tech_support": "社群交流", + "sla_compliance": "无", + "page_customization": "无", + "theme_color": "#64748B", + "quotas": { + "workspace_quota": 1, + "skill_quota": 5, + "app_quota": 2, + "knowledge_capacity_quota": 0.3, + "memory_engine_quota": 1, + "end_user_quota": 1, + "ontology_project_quota": 3, + "model_quota": 1, + "api_ops_rate_limit": 50, + }, +} diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 50e9e0b0..377205c4 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -47,7 +47,8 @@ from . import ( user_memory_controllers, workspace_controller, ontology_controller, - skill_controller + skill_controller, + tenant_subscription_controller, ) # 创建管理端 API 路由器 @@ -98,5 +99,6 @@ manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) manager_router.include_router(skill_controller.router) manager_router.include_router(i18n_controller.router) +manager_router.include_router(tenant_subscription_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 74991bcf..34449bb5 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -28,6 +28,7 @@ from app.services.app_statistics_service import AppStatisticsService from app.services.workflow_import_service import WorkflowImportService from app.services.workflow_service import WorkflowService, get_workflow_service from app.services.app_dsl_service import AppDslService +from app.core.quota_stub import check_app_quota router = APIRouter(prefix="/apps", tags=["Apps"]) logger = get_business_logger() @@ -35,6 +36,7 @@ logger = get_business_logger() @router.post("", summary="创建应用(可选创建 Agent 配置)") @cur_workspace_access_guard() +@check_app_quota def create_app( payload: app_schema.AppCreate, db: Session = 
Depends(get_db), @@ -292,10 +294,19 @@ def get_opening( ): """返回开场白文本和预设问题,供前端对话界面初始化时展示""" workspace_id = current_user.current_workspace_id - cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) - features = cfg.features or {} - if hasattr(features, "model_dump"): - features = features.model_dump() + + # 根据应用类型获取 features + from app.models.app_model import App as AppModel + app = db.get(AppModel, app_id) + if app and app.type == "workflow": + cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id) + features = cfg.features or {} + else: + cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) + features = cfg.features or {} + if hasattr(features, "model_dump"): + features = features.model_dump() + opening = features.get("opening_statement", {}) return success(data=app_schema.OpeningResponse( enabled=opening.get("enabled", False), @@ -1070,6 +1081,14 @@ async def update_workflow_config( current_user: Annotated[User, Depends(get_current_user)] ): workspace_id = current_user.current_workspace_id + if payload.variables: + from app.services.workflow_service import WorkflowService + resolved = await WorkflowService(db)._resolve_variables_file_defaults( + [v.model_dump() for v in payload.variables] + ) + # Patch default values back into VariableDefinition objects + for var_def, resolved_def in zip(payload.variables, resolved): + var_def.default = resolved_def.get("default", var_def.default) cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) return success(data=WorkflowConfigSchema.model_validate(cfg)) @@ -1233,9 +1252,11 @@ async def export_app( async def import_app( file: UploadFile = File(...), db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + app_id: Optional[str] = Form(None), ): """从 YAML 文件导入 agent / multi_agent / workflow 应用。 + 传入 app_id 
时覆盖该应用的配置(类型必须一致),否则创建新应用。 跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。 """ if not file.filename.lower().endswith((".yaml", ".yml")): @@ -1246,13 +1267,15 @@ async def import_app( if not dsl or "app" not in dsl: return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST) - new_app, warnings = AppDslService(db).import_dsl( + target_app_id = uuid.UUID(app_id) if app_id else None + result_app, warnings = AppDslService(db).import_dsl( dsl=dsl, workspace_id=current_user.current_workspace_id, tenant_id=current_user.tenant_id, user_id=current_user.id, + app_id=target_app_id, ) return success( - data={"app": app_schema.App.model_validate(new_app), "warnings": warnings}, + data={"app": app_schema.App.model_validate(result_app), "warnings": warnings}, msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "") ) diff --git a/api/app/controllers/auth_controller.py b/api/app/controllers/auth_controller.py index 2cc72a3b..baae44a6 100644 --- a/api/app/controllers/auth_controller.py +++ b/api/app/controllers/auth_controller.py @@ -53,22 +53,24 @@ async def login_for_access_token( user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password) auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})") if form_data.invite: - auth_service.bind_workspace_with_invite(db=db, - user=user, - invite_token=form_data.invite, - workspace_id=invite_info.workspace_id) + auth_service.bind_workspace_with_invite( + db=db, + user=user, + invite_token=form_data.invite, + workspace_id=invite_info.workspace_id + ) except BusinessException as e: # 用户不存在且有邀请码,尝试注册 if e.code == BizCode.USER_NOT_FOUND: auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}") user = auth_service.register_user_with_invite( - db=db, - email=form_data.email, - username=form_data.username, - password=form_data.password, - invite_token=form_data.invite, - workspace_id=invite_info.workspace_id - ) + db=db, + email=form_data.email, + username=form_data.username, + password=form_data.password, + 
invite_token=form_data.invite, + workspace_id=invite_info.workspace_id + ) elif e.code == BizCode.PASSWORD_ERROR: # 用户存在但密码错误 auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}") @@ -134,7 +136,7 @@ async def refresh_token( # 检查用户是否存在 user = auth_service.get_user_by_id(db, userId) if not user: - raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND) + raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS) # 检查 refresh token 黑名单 if settings.ENABLE_SINGLE_SESSION: diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 988aa706..cc1f8c98 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -23,6 +23,7 @@ from app.models.user_model import User from app.schemas import chunk_schema from app.schemas.response_schema import ApiResponse from app.services import knowledge_service, document_service, file_service, knowledgeshare_service +from app.services.model_service import ModelApiKeyService # Obtain a dedicated API logger api_logger = get_api_logger() @@ -442,10 +443,10 @@ async def retrieve_chunks( match retrieve_data.retrieve_type: case chunk_schema.RetrieveType.PARTICIPLE: rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) - return success(data=rs, msg="retrieval successful") + return success(data=jsonable_encoder(rs), msg="retrieval successful") case chunk_schema.RetrieveType.SEMANTIC: rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) - return success(data=rs, msg="retrieval successful") + return success(data=jsonable_encoder(rs), msg="retrieval successful") case _: rs1 = 
vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) @@ -460,18 +461,20 @@ async def retrieve_chunks( if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph: kb_ids = [str(kb_id) for kb_id in private_kb_ids] workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids] + llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id) + emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id) # Prepare to configure chat_mdl、embedding_model、vision_model information chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base + key=llm_key.api_key, + model_name=llm_key.model_name, + base_url=llm_key.api_base ) embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base + key=emb_key.api_key, + model_name=emb_key.model_name, + base_url=emb_key.api_base ) - doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model) + doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) if doc: rs.insert(0, doc) return success(data=jsonable_encoder(rs), msg="retrieval successful") \ No newline at end of file diff --git a/api/app/controllers/document_controller.py b/api/app/controllers/document_controller.py index 72f9cb8f..350acc0e 100644 --- 
a/api/app/controllers/document_controller.py +++ b/api/app/controllers/document_controller.py @@ -314,8 +314,10 @@ async def parse_documents( ) # 4. Check if the file exists + api_logger.debug(f"Constructed file path: {file_path}") + api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}") if not os.path.exists(file_path): - api_logger.warning(f"File not found (possibly deleted): file_path={file_path}") + api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="File not found (possibly deleted)" diff --git a/api/app/controllers/file_controller.py b/api/app/controllers/file_controller.py index f7bd0e7a..6f8b1b97 100644 --- a/api/app/controllers/file_controller.py +++ b/api/app/controllers/file_controller.py @@ -19,6 +19,7 @@ from app.models.user_model import User from app.schemas import file_schema, document_schema from app.schemas.response_schema import ApiResponse from app.services import file_service, document_service +from app.core.quota_stub import check_knowledge_capacity_quota # Obtain a dedicated API logger @@ -131,6 +132,7 @@ async def create_folder( @router.post("/file", response_model=ApiResponse) +@check_knowledge_capacity_quota async def upload_file( kb_id: uuid.UUID, parent_id: uuid.UUID, diff --git a/api/app/controllers/home_page_controller.py b/api/app/controllers/home_page_controller.py index de4a78a3..400d155a 100644 --- a/api/app/controllers/home_page_controller.py +++ b/api/app/controllers/home_page_controller.py @@ -3,9 +3,10 @@ from sqlalchemy.orm import Session from app.core.config import settings from app.core.response_utils import success -from app.db import get_db +from app.db import get_db, SessionLocal from app.dependencies import get_current_user from app.models.user_model import User +from 
app.repositories.home_page_repository import HomePageRepository from app.schemas.response_schema import ApiResponse from app.services.home_page_service import HomePageService @@ -31,9 +32,32 @@ def get_workspace_list( @router.get("/version", response_model=ApiResponse) def get_system_version(): - """获取系统版本号+说明""" - current_version = settings.SYSTEM_VERSION - version_info = HomePageService.load_version_introduction(current_version) + """获取系统版本号 + 说明""" + current_version = None + version_info = None + + # 1️⃣ 优先从数据库获取最新已发布的版本 + try: + db = SessionLocal() + try: + current_version, version_info = HomePageRepository.get_latest_version_introduction(db) + finally: + db.close() + except Exception as e: + pass + + # 2️⃣ 降级:使用环境变量中的版本号 + if not current_version: + current_version = settings.SYSTEM_VERSION + version_info = HomePageService.load_version_introduction(current_version) + + # 3️⃣ 如果数据库和 JSON 都没有,返回基本信息 + if not version_info: + version_info = { + "introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}, + "introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []} + } + return success( data={ "version": current_version, diff --git a/api/app/controllers/knowledge_controller.py b/api/app/controllers/knowledge_controller.py index 74b832cd..5cd87647 100644 --- a/api/app/controllers/knowledge_controller.py +++ b/api/app/controllers/knowledge_controller.py @@ -27,6 +27,7 @@ from app.schemas import knowledge_schema from app.schemas.response_schema import ApiResponse from app.services import knowledge_service, document_service from app.services.model_service import ModelConfigService +from app.core.quota_stub import check_knowledge_capacity_quota # Obtain a dedicated API logger api_logger = get_api_logger() @@ -179,6 +180,7 @@ async def get_knowledges( @router.post("/knowledge", response_model=ApiResponse) +@check_knowledge_capacity_quota async def create_knowledge( create_data: 
knowledge_schema.KnowledgeCreate, db: Session = Depends(get_db), @@ -352,6 +354,7 @@ async def delete_knowledge( # 2. Soft-delete knowledge base api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})") db_knowledge.status = 2 + db_knowledge.updated_at = datetime.datetime.now() db.commit() api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})") return success(msg="The knowledge base has been successfully deleted") diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index bedee987..525fe1eb 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -591,7 +591,7 @@ async def dashboard_data( "total_api_call": None } - # 1. 获取记忆总量(total_memory) + # 1. 获取记忆总量(total_memory)—— neo4j 独有逻辑:查询 neo4j 存储节点 try: total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count( db=db, @@ -600,49 +600,33 @@ async def dashboard_data( end_user_id=end_user_id ) neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0) - # total_app: 统计当前空间下的所有app数量 - # 包含自有app + 被分享给本工作空间的app - from app.services import app_service as _app_svc - _, total_app = _app_svc.AppService(db).list_apps( - workspace_id=workspace_id, include_shared=True, pagesize=1 - ) - neo4j_data["total_app"] = total_app - api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}") + api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}") except Exception as e: api_logger.warning(f"获取记忆总量失败: {str(e)}") - # 2. 
获取知识库类型统计(total_knowledge) - try: - from app.services.memory_agent_service import MemoryAgentService - memory_agent_service = MemoryAgentService() - knowledge_stats = await memory_agent_service.get_knowledge_type_stats( - end_user_id=end_user_id, - only_active=True, - current_workspace_id=workspace_id, - db=db - ) - neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0) - api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}") - except Exception as e: - api_logger.warning(f"获取知识库类型统计失败: {str(e)}") + # 2. 获取共享统计数据(total_app、total_knowledge、total_api_call) + common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id) + neo4j_data.update(common_stats) + api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}") - # 3. 获取API调用统计(total_api_call) + # 计算昨日对比 try: - # 使用 AppStatisticsService 获取真实的API调用统计 - app_stats_service = AppStatisticsService(db) - api_stats = app_stats_service.get_workspace_api_statistics( + changes = memory_dashboard_service.get_dashboard_yesterday_changes( + db=db, workspace_id=workspace_id, - start_date=start_date, - end_date=end_date + storage_type=storage_type, + today_data=neo4j_data ) - # 计算总调用次数 - total_api_calls = sum(item.get("total_calls", 0) for item in api_stats) - neo4j_data["total_api_call"] = total_api_calls - api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}") + neo4j_data.update(changes) except Exception as e: - api_logger.error(f"获取API调用统计失败: {str(e)}") - neo4j_data["total_api_call"] = 0 - + api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}") + neo4j_data.update({ + "total_memory_change": None, + "total_app_change": None, + "total_knowledge_change": None, + "total_api_call_change": None, + }) + result["neo4j_data"] = neo4j_data api_logger.info("成功获取neo4j_data") @@ -655,44 +639,37 @@ async def dashboard_data( "total_api_call": None } - # 获取RAG相关数据 + # 1. 
获取记忆总量(total_memory)—— rag 独有逻辑:查询 document 表的 chunk_num try: - # total_memory: 只统计用户知识库(permission_id='Memory')的chunk数 total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user) rag_data["total_memory"] = total_chunk - - # total_app: 统计当前空间下的所有app数量 - # 包含自有app + 被分享给本工作空间的app - from app.services import app_service as _app_svc - _, total_app = _app_svc.AppService(db).list_apps( - workspace_id=workspace_id, include_shared=True, pagesize=1 - ) - rag_data["total_app"] = total_app - - # total_knowledge: 使用 total_kb(总知识库数) - total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) - rag_data["total_knowledge"] = total_kb - - # total_api_call: 使用 AppStatisticsService 获取真实的API调用统计 - try: - app_stats_service = AppStatisticsService(db) - api_stats = app_stats_service.get_workspace_api_statistics( - workspace_id=workspace_id, - start_date=start_date, - end_date=end_date - ) - # 计算总调用次数 - total_api_calls = sum(item.get("total_calls", 0) for item in api_stats) - rag_data["total_api_call"] = total_api_calls - api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}") - except Exception as e: - api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}") - rag_data["total_api_call"] = 0 - - api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") + api_logger.info(f"成功获取RAG记忆总量: {total_chunk}") except Exception as e: - api_logger.warning(f"获取RAG相关数据失败: {str(e)}") + api_logger.warning(f"获取RAG记忆总量失败: {str(e)}") + # 2. 
获取共享统计数据(total_app、total_knowledge、total_api_call) + common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id) + rag_data.update(common_stats) + api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}") + + # 计算昨日对比 + try: + changes = memory_dashboard_service.get_dashboard_yesterday_changes( + db=db, + workspace_id=workspace_id, + storage_type=storage_type, + today_data=rag_data + ) + rag_data.update(changes) + except Exception as e: + api_logger.warning(f"计算RAG昨日对比失败: {str(e)}") + rag_data.update({ + "total_memory_change": None, + "total_app_change": None, + "total_knowledge_change": None, + "total_api_call_change": None, + }) + result["rag_data"] = rag_data api_logger.info("成功获取rag_data") diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index d8b39325..545f8302 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -26,7 +26,7 @@ from app.services.memory_storage_service import ( analytics_hot_memory_tags, analytics_recent_activity_stats, kb_type_distribution, - search_all, + search_all_batch, search_chunk, search_detials, search_dialogue, @@ -34,6 +34,7 @@ from app.services.memory_storage_service import ( search_entity, search_statement, ) +from app.core.quota_stub import check_memory_engine_quota from fastapi import APIRouter, Depends, Header from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -76,6 +77,7 @@ async def get_storage_info( @router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 +@check_memory_engine_quota def create_config( payload: ConfigParamsCreate, current_user: User = Depends(get_current_user), @@ -409,7 +411,10 @@ async def search_all_num( ) -> dict: api_logger.info(f"Search all requested for end_user_id: {end_user_id}") try: - result = await 
search_all(end_user_id) + if not end_user_id: + return success(data={"total": 0}, msg="查询成功") + batch_result = await search_all_batch([end_user_id]) + result = {"total": batch_result.get(end_user_id, 0)} return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"Search all failed: {str(e)}") diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 71fd41ad..6105c3d8 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -15,6 +15,7 @@ from app.core.response_utils import success from app.schemas.response_schema import ApiResponse, PageData from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService from app.core.logging_config import get_api_logger +from app.core.quota_stub import check_model_quota, check_model_activation_quota # 获取API专用日志器 api_logger = get_api_logger() @@ -236,6 +237,7 @@ def delete_model_base( @router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse) +@check_model_quota def add_model_from_plaza( model_base_id: uuid.UUID, db: Session = Depends(get_db), @@ -273,6 +275,7 @@ def get_model_by_id( @router.post("", response_model=ApiResponse) +@check_model_quota async def create_model( model_data: model_schema.ModelConfigCreate, db: Session = Depends(get_db), @@ -303,6 +306,7 @@ async def create_model( @router.post("/composite", response_model=ApiResponse) +@check_model_quota async def create_composite_model( model_data: model_schema.CompositeModelCreate, db: Session = Depends(get_db), @@ -329,6 +333,7 @@ async def create_composite_model( @router.put("/composite/{model_id}", response_model=ApiResponse) +@check_model_activation_quota async def update_composite_model( model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, @@ -370,6 +375,7 @@ def delete_composite_model( @router.put("/{model_id}", response_model=ApiResponse) +@check_model_activation_quota def update_model( 
model_id: uuid.UUID, model_data: model_schema.ModelConfigUpdate, diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 3d2a1bdb..602ee709 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -28,6 +28,8 @@ from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, H from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session +from app.core.quota_stub import check_ontology_project_quota + from app.core.config import settings from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header @@ -163,6 +165,7 @@ def _get_ontology_service( api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, + capability=api_key_config.capability, max_retries=3, timeout=60.0 ) @@ -286,6 +289,7 @@ async def extract_ontology( # ==================== 本体场景管理接口 ==================== @router.post("/scene", response_model=ApiResponse) +@check_ontology_project_quota async def create_scene( request: SceneCreateRequest, db: Session = Depends(get_db), diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index 80f14cd3..b9fc697c 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -124,10 +124,11 @@ async def get_prompt_opt( skill=data.skill ): # chunk 是 prompt 的增量内容 - yield f"event:message\ndata: {json.dumps(chunk)}\n\n" + yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n" except Exception as e: yield f"event:error\ndata: {json.dumps( - {"error": str(e)} + {"error": str(e)}, + ensure_ascii=False )}\n\n" yield "event:end\ndata: {}\n\n" diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index c10ad14b..ddd31071 100644 --- 
a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -453,6 +453,9 @@ async def chat( # 流式返回 agent_config = agent_config_4_app_release(release) + if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking): + agent_config.model_parameters["deep_thinking"] = False + if payload.stream: async def event_generator(): async for event in app_chat_service.agnet_chat_stream( @@ -634,7 +637,8 @@ async def config_query( "app_type": release.app.type, "variables": release.config.get("variables"), "memory": release.config.get("memory", {}).get("enabled"), - "features": release.config.get("features") + "features": release.config.get("features"), + "model_parameters": release.config.get("model_parameters") } elif release.app.type == AppType.MULTI_AGENT: content = { diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index d4573464..a78fd842 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -14,6 +14,7 @@ from app.core.response_utils import success from app.db import get_db from app.models.app_model import App from app.models.app_model import AppType +from app.models.app_release_model import AppRelease from app.repositories import knowledge_repository from app.repositories.end_user_repository import EndUserRepository from app.schemas import AppChatRequest, conversation_schema @@ -61,18 +62,18 @@ async def list_apps(): # return success(data={"received": True}, msg="消息已接收") -def _checkAppConfig(app: App): - if app.type == AppType.AGENT: - if not app.current_release.config: +def _checkAppConfig(release: AppRelease): + if release.type == AppType.AGENT: + if not release.config: raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) - elif app.type == AppType.MULTI_AGENT: - if not app.current_release.config: + elif release.type == AppType.MULTI_AGENT: + if not 
release.config: raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) - elif app.type == AppType.WORKFLOW: - if not app.current_release.config: + elif release.type == AppType.WORKFLOW: + if not release.config: raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING) else: - raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING) + raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED) @router.post("/chat") @@ -86,10 +87,22 @@ async def chat( app_service: Annotated[AppService, Depends(get_app_service)] = None, message: str = Body(..., description="聊天消息内容"), ): + """ + Agent/Workflow 聊天接口 + + - 不传 version:使用当前生效版本(current_release,回滚后为回滚目标版本) + - 传 version=release_id:使用指定版本uuid的历史快照,例如 {"version": "{{release_id}}"} + """ body = await request.json() payload = AppChatRequest(**body) app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) + + # 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本 + if payload.version is not None: + active_release = app_service.get_release_by_id(app.id, payload.version) + else: + active_release = app.current_release other_id = payload.user_id workspace_id = api_key_auth.workspace_id end_user_repo = EndUserRepository(db) @@ -127,7 +140,7 @@ async def chat( storage_type = 'neo4j' app_type = app.type # check app config - _checkAppConfig(app) + _checkAppConfig(active_release) # 获取或创建会话(提前验证) conversation = conversation_service.create_or_get_conversation( @@ -142,8 +155,13 @@ async def chat( # print("="*50) # print(app.current_release.default_model_config_id) - agent_config = agent_config_4_app_release(app.current_release) + agent_config = agent_config_4_app_release(active_release) # print(agent_config.default_model_config_id) + + # thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用 + if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking): + agent_config.model_parameters["deep_thinking"] = False + # 流式返回 if payload.stream: async def 
event_generator(): @@ -189,7 +207,7 @@ async def chat( return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: # 多 Agent 流式返回 - config = multi_agent_config_4_app_release(app.current_release) + config = multi_agent_config_4_app_release(active_release) if payload.stream: async def event_generator(): async for event in app_chat_service.multi_agent_chat_stream( @@ -232,7 +250,7 @@ async def chat( return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.WORKFLOW: # 多 Agent 流式返回 - config = workflow_config_4_app_release(app.current_release) + config = workflow_config_4_app_release(active_release) if payload.stream: async def event_generator(): async for event in app_chat_service.workflow_chat_stream( @@ -248,7 +266,7 @@ async def chat( user_rag_memory_id=user_rag_memory_id, app_id=app.id, workspace_id=workspace_id, - release_id=app.current_release.id, + release_id=active_release.id, public=True ): event_type = event.get("event", "message") @@ -283,7 +301,7 @@ async def chat( files=payload.files, app_id=app.id, workspace_id=workspace_id, - release_id=app.current_release.id + release_id=active_release.id ) logger.debug( "工作流试运行返回结果", @@ -297,6 +315,4 @@ async def chat( msg="工作流任务执行成功" ) else: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) diff --git a/api/app/controllers/service/end_user_api_controller.py b/api/app/controllers/service/end_user_api_controller.py index a6dc224f..1faea6ef 100644 --- a/api/app/controllers/service/end_user_api_controller.py +++ b/api/app/controllers/service/end_user_api_controller.py @@ -10,6 +10,7 @@ from app.core.api_key_auth import require_api_key from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger 
+from app.core.quota_stub import check_end_user_quota from app.core.response_utils import success from app.db import get_db from app.repositories.end_user_repository import EndUserRepository @@ -41,6 +42,7 @@ def _get_current_user(api_key_auth: ApiKeyAuth, db: Session): @router.post("/create") @require_api_key(scopes=["memory"]) +@check_end_user_quota async def create_end_user( request: Request, api_key_auth: ApiKeyAuth = None, @@ -62,7 +64,7 @@ async def create_end_user( payload = CreateEndUserRequest(**body) workspace_id = api_key_auth.workspace_id - logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {workspace_id}") + logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id) # Resolve memory_config_id: explicit > workspace default memory_config_id = None diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 9acd865f..313781d2 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from app.core.api_key_auth import require_api_key from app.core.logging_config import get_business_logger +from app.core.quota_stub import check_end_user_quota from app.core.response_utils import success from app.db import get_db from app.schemas.api_key_schema import ApiKeyAuth @@ -167,6 +168,7 @@ async def get_read_task_status( @router.post("/write/sync") @require_api_key(scopes=["memory"]) +@check_end_user_quota async def write_memory_sync( request: Request, api_key_auth: ApiKeyAuth = None, diff --git a/api/app/controllers/skill_controller.py b/api/app/controllers/skill_controller.py index 6e673679..4ee07c7d 100644 --- a/api/app/controllers/skill_controller.py +++ b/api/app/controllers/skill_controller.py @@ -11,11 +11,13 @@ from app.schemas import skill_schema from app.schemas.response_schema import PageData, 
PageMeta from app.services.skill_service import SkillService from app.core.response_utils import success +from app.core.quota_stub import check_skill_quota router = APIRouter(prefix="/skills", tags=["Skills"]) @router.post("", summary="创建技能") +@check_skill_quota def create_skill( data: skill_schema.SkillCreate, db: Session = Depends(get_db), diff --git a/api/app/controllers/tenant_subscription_controller.py b/api/app/controllers/tenant_subscription_controller.py new file mode 100644 index 00000000..c3fde572 --- /dev/null +++ b/api/app/controllers/tenant_subscription_controller.py @@ -0,0 +1,82 @@ +""" +租户套餐查询接口(普通用户可访问) +""" +import datetime +from typing import Callable + +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse +from sqlalchemy.orm import Session + +from app.core.logging_config import get_api_logger +from app.core.response_utils import success, fail +from app.db import get_db +from app.dependencies import get_current_user +from app.i18n.dependencies import get_translator +from app.models.user_model import User +from app.schemas.response_schema import ApiResponse + +logger = get_api_logger() + +router = APIRouter(prefix="/tenant", tags=["Tenant"]) + + +@router.get("/subscription", response_model=ApiResponse, summary="获取当前用户所属租户的套餐信息") +async def get_my_tenant_subscription( + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + t: Callable = Depends(get_translator), +): + """ + 获取当前登录用户所属租户的有效套餐订阅信息。 + 包含套餐名称、版本、配额、到期时间等。 + """ + try: + from premium.platform_admin.package_plan_service import TenantSubscriptionService + + if not current_user.tenant: + return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户")) + + tenant_id = current_user.tenant.id + svc = TenantSubscriptionService(db) + sub = svc.get_subscription(tenant_id) + + if not sub: + return success(data=None, msg="暂无有效套餐") + + return success(data=svc.build_response(sub)) + + except ModuleNotFoundError: + # 社区版无 premium 
模块,从配置文件读取免费套餐 + if not current_user.tenant: + return JSONResponse(status_code=404, content=fail(code=404, msg="用户未关联租户")) + + from app.config.default_free_plan import DEFAULT_FREE_PLAN + + plan = DEFAULT_FREE_PLAN + response_data = { + "subscription_id": None, + "tenant_id": str(current_user.tenant.id), + "package_plan_id": None, + "package_version": plan["version"], + "package_plan": { + "id": None, + "name": plan["name"], + "version": plan["version"], + "category": plan["category"], + "tier_level": plan["tier_level"], + "price": float(plan["price"]), + "billing_cycle": plan["billing_cycle"], + }, + "started_at": None, + "expired_at": None, + "status": "active", + "quota": plan["quotas"], + "created_at": int(datetime.datetime.utcnow().timestamp() * 1000), + "updated_at": int(datetime.datetime.utcnow().timestamp() * 1000), + } + return success(data=response_data, msg="社区版免费套餐") + + except Exception as e: + logger.error(f"获取租户套餐信息失败: {e}", exc_info=True) + return JSONResponse(status_code=500, content=fail(code=500, msg="获取套餐信息失败")) diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index cc16a6b4..5a329165 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -114,11 +114,14 @@ def get_current_user_info( # 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限 if current_user.external_source: - from premium.sso.models import SSOSource - source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() - if source and source.permissions: - result_schema.permissions = source.permissions - else: + try: + from premium.sso.models import SSOSource + source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() + if source and source.permissions: + result_schema.permissions = source.permissions + else: + result_schema.permissions = [] + except ModuleNotFoundError: result_schema.permissions = [] else: 
result_schema.permissions = ["all"] diff --git a/api/app/controllers/workspace_controller.py b/api/app/controllers/workspace_controller.py index 6f4a4fa8..47068288 100644 --- a/api/app/controllers/workspace_controller.py +++ b/api/app/controllers/workspace_controller.py @@ -35,6 +35,7 @@ from app.schemas.workspace_schema import ( WorkspaceUpdate, ) from app.services import workspace_service +from app.core.quota_stub import check_workspace_quota # 获取API专用日志器 api_logger = get_api_logger() @@ -106,6 +107,7 @@ def get_workspaces( @router.post("", response_model=ApiResponse) +@check_workspace_quota def create_workspace( workspace: WorkspaceCreate, language_type: str = Header(default="zh", alias="X-Language-Type"), diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 38821313..927eb734 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -14,6 +14,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool +from langgraph.errors import GraphRecursionError from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig @@ -37,7 +38,11 @@ class LangChainAgent: tools: Optional[Sequence[BaseTool]] = None, streaming: bool = False, max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算) - max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数 + max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数 + deep_thinking: bool = False, # 是否启用深度思考模式 + thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算 + json_output: bool = False, # 是否强制 JSON 输出 + capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考 ): """初始化 LangChain Agent @@ -75,6 +80,12 @@ class LangChainAgent: self.system_prompt = system_prompt or "你是一个专业的AI助手" + # ChatTongyi 要求 messages 含 
'json' 字样才能使用 response_format + # 在 system prompt 中注入 JSON 要求 + from app.models.models_model import ModelProvider + if json_output and provider.lower() == ModelProvider.DASHSCOPE and not is_omni: + self.system_prompt += "\n请以JSON格式输出。" + logger.debug( f"Agent 迭代次数配置: max_iterations={self.max_iterations}, " f"tool_count={len(self.tools)}, " @@ -82,21 +93,28 @@ class LangChainAgent: f"auto_calculated={max_iterations is None}" ) - # 创建 RedBearLLM(支持多提供商) + # 创建 RedBearLLM,capability 校验由 RedBearModelConfig 统一处理 model_config = RedBearModelConfig( model_name=model_name, provider=provider, api_key=api_key, base_url=api_base, is_omni=is_omni, + capability=capability, + deep_thinking=deep_thinking, + thinking_budget_tokens=thinking_budget_tokens, + json_output=json_output, extra_params={ "temperature": temperature, "max_tokens": max_tokens, - "streaming": streaming # 使用参数控制流式 + "streaming": streaming } ) self.llm = RedBearLLM(model_config, type=ModelType.CHAT) + # 从经过校验的 config 读取实际生效的能力开关 + self.deep_thinking = model_config.deep_thinking + self.json_output = model_config.json_output # 获取底层模型用于真正的流式调用 self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm @@ -249,6 +267,33 @@ class LangChainAgent: return messages + @staticmethod + def _extract_tokens_from_message(msg) -> int: + """从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式 + + 支持的格式: + - response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI) + - response_metadata.usage.total_tokens (部分 provider) + - usage_metadata.total_tokens (LangChain 新版) + """ + total = 0 + # 1. response_metadata + response_meta = getattr(msg, "response_metadata", None) + if response_meta and isinstance(response_meta, dict): + # 尝试 token_usage 路径 + token_usage = response_meta.get("token_usage") or response_meta.get("usage", {}) + if isinstance(token_usage, dict): + total = token_usage.get("total_tokens", 0) + # 2. 
usage_metadata(LangChain 新版 AIMessage 属性) + if not total: + usage_meta = getattr(msg, "usage_metadata", None) + if usage_meta: + if isinstance(usage_meta, dict): + total = usage_meta.get("total_tokens", 0) + else: + total = getattr(usage_meta, "total_tokens", 0) + return total or 0 + def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 构建多模态消息内容 @@ -283,6 +328,17 @@ class LangChainAgent: return content_parts + @staticmethod + def _extract_reasoning_content(msg) -> str: + """从 AIMessage 中提取深度思考内容(reasoning_content) + + 所有 provider 统一通过 additional_kwargs.reasoning_content 传递: + - DeepSeek-R1 / QwQ: 原生字段 + - Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入 + """ + additional = getattr(msg, "additional_kwargs", None) or {} + return additional.get("reasoning_content") or additional.get("reasoning", "") + async def chat( self, message: str, @@ -325,7 +381,7 @@ class LangChainAgent: {"messages": messages}, config={"recursion_limit": self.max_iterations} ) - except RecursionError as e: + except (RecursionError, GraphRecursionError) as e: logger.warning( f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环", extra={"error": str(e)} @@ -348,6 +404,7 @@ class LangChainAgent: logger.debug(f"输出消息数量: {len(output_messages)}") total_tokens = 0 + reasoning_content = "" for msg in reversed(output_messages): if isinstance(msg, AIMessage): logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}") @@ -382,8 +439,8 @@ class LangChainAgent: else: content = str(msg.content) logger.debug(f"转换为字符串: {content[:100]}...") - response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 + total_tokens = self._extract_tokens_from_message(msg) + reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else "" break logger.info(f"最终提取的内容长度: {len(content)}") @@ -399,6 +456,8 @@ 
class LangChainAgent: "total_tokens": total_tokens } } + if reasoning_content: + response["reasoning_content"] = reasoning_content logger.debug( "Agent 调用完成", @@ -420,7 +479,7 @@ class LangChainAgent: history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[str | int | dict[str, str], None]: """执行流式对话 Args: @@ -431,6 +490,8 @@ class LangChainAgent: Yields: str: 消息内容块 + int: token 统计 + Dict: 深度思考内容 {"type": "reasoning", "content": "..."} """ logger.info("=" * 80) logger.info(" chat_stream 方法开始执行") @@ -451,6 +512,7 @@ class LangChainAgent: # 统一使用 agent 的 astream_events 实现流式输出 logger.debug("使用 Agent astream_events 实现流式输出") full_content = '' + full_reasoning = '' try: last_event = {} async for event in self.agent.astream_events( @@ -467,6 +529,13 @@ class LangChainAgent: # LLM 流式输出 chunk = event.get("data", {}).get("chunk") if chunk and hasattr(chunk, "content"): + # 提取深度思考内容(仅在启用深度思考时) + if self.deep_thinking: + reasoning_chunk = self._extract_reasoning_content(chunk) + if reasoning_chunk: + full_reasoning += reasoning_chunk + yield {"type": "reasoning", "content": reasoning_chunk} + # 处理多模态响应:content 可能是字符串或列表 chunk_content = chunk.content if isinstance(chunk_content, str) and chunk_content: @@ -497,6 +566,13 @@ class LangChainAgent: chunk = event.get("data", {}).get("chunk") if chunk: if hasattr(chunk, "content"): + # 提取深度思考内容(仅在启用深度思考时) + if self.deep_thinking: + reasoning_chunk = self._extract_reasoning_content(chunk) + if reasoning_chunk: + full_reasoning += reasoning_chunk + yield {"type": "reasoning", "content": reasoning_chunk} + chunk_content = chunk.content if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content @@ -535,14 +611,17 @@ class LangChainAgent: output_messages = last_event.get("data", {}).get("output", {}).get("messages", []) for msg in reversed(output_messages): if isinstance(msg, AIMessage): - 
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get( - "total_tokens", - 0 - ) if response_meta else 0 - yield total_tokens + stream_total_tokens = self._extract_tokens_from_message(msg) + logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}") + yield stream_total_tokens break + except GraphRecursionError: + logger.warning( + f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断" + ) + if not full_content: + yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。" except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/api_key_auth.py b/api/app/core/api_key_auth.py index 342405b8..91d6bd8a 100644 --- a/api/app/core/api_key_auth.py +++ b/api/app/core/api_key_auth.py @@ -96,6 +96,38 @@ def require_api_key( resource_id=api_key_obj.resource_id, ) + # ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit)────────── + try: + from app.models.workspace_model import Workspace + from premium.platform_admin.package_plan_service import TenantSubscriptionService + + workspace = db.query(Workspace).filter( + Workspace.id == api_key_obj.workspace_id + ).first() + if workspace: + quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id) + tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None + if tenant_qps_limit: + rate_limiter = RateLimiterService() + tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit( + workspace.tenant_id, tenant_qps_limit + ) + if not tenant_ok: + raise RateLimitException( + "租户 API 调用速率超限", + BizCode.API_KEY_QPS_LIMIT_EXCEEDED, + rate_headers={ + "X-RateLimit-Tenant-Limit": str(tenant_info["limit"]), + "X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]), + "X-RateLimit-Tenant-Reset": str(tenant_info["reset"]), + } + ) + except RateLimitException: + raise + except Exception as e: + logger.warning(f"Tenant 限速检查异常,跳过: {e}") + # 
───────────────────────────────────────────────────────────── + rate_limiter = RateLimiterService() is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj) if not is_allowed: diff --git a/api/app/core/error_codes.py b/api/app/core/error_codes.py index 3feae4f6..01b6115d 100644 --- a/api/app/core/error_codes.py +++ b/api/app/core/error_codes.py @@ -19,6 +19,7 @@ class BizCode(IntEnum): TENANT_NOT_FOUND = 3002 WORKSPACE_NO_ACCESS = 3003 WORKSPACE_INVITE_NOT_FOUND = 3004 + WORKSPACE_ACCESS_DENIED = 3005 # API Key 管理(3xxx) API_KEY_NOT_FOUND = 3007 API_KEY_DUPLICATE_NAME = 3008 @@ -40,6 +41,7 @@ class BizCode(IntEnum): FILE_NOT_FOUND = 4006 APP_NOT_FOUND = 4007 RELEASE_NOT_FOUND = 4008 + USER_NO_ACCESS = 4009 # 冲突/状态(5xxx) DUPLICATE_NAME = 5001 @@ -113,8 +115,11 @@ HTTP_MAPPING = { BizCode.FORBIDDEN: 403, BizCode.TENANT_NOT_FOUND: 400, BizCode.WORKSPACE_NO_ACCESS: 403, + BizCode.WORKSPACE_INVITE_NOT_FOUND: 400, + BizCode.WORKSPACE_ACCESS_DENIED: 403, BizCode.NOT_FOUND: 400, BizCode.USER_NOT_FOUND: 200, + BizCode.USER_NO_ACCESS: 401, BizCode.WORKSPACE_NOT_FOUND: 400, BizCode.MODEL_NOT_FOUND: 400, BizCode.KNOWLEDGE_NOT_FOUND: 400, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py new file mode 100644 index 00000000..1cf5e291 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py @@ -0,0 +1,408 @@ +""" +Perceptual Memory Retrieval Node & Service + +Provides PerceptualSearchService for searching perceptual memories (vision, audio, +text, conversation) from Neo4j using keyword fulltext + embedding semantic search +with BM25+embedding fusion reranking. + +Also provides the perceptual_retrieve_node for use as a LangGraph node. 
+""" +import asyncio +import math +from typing import List, Dict, Any, Optional + +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.llm_tools import ReadState +from app.core.memory.utils.data.text_utils import escape_lucene_query +from app.repositories.neo4j.graph_search import ( + search_perceptual, + search_perceptual_by_embedding, +) +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = get_agent_logger(__name__) + + +class PerceptualSearchService: + """ + 感知记忆检索服务。 + + 封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。 + 调用方只需提供 query / keywords、end_user_id、memory_config,即可获得 + 格式化并排序后的感知记忆列表和拼接文本。 + + Usage: + service = PerceptualSearchService(end_user_id=..., memory_config=...) + results = await service.search(query="...", keywords=[...], limit=10) + # results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M} + """ + + DEFAULT_ALPHA = 0.6 + DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5 + + def __init__( + self, + end_user_id: str, + memory_config: Any, + alpha: float = DEFAULT_ALPHA, + content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD, + ): + self.end_user_id = end_user_id + self.memory_config = memory_config + self.alpha = alpha + self.content_score_threshold = content_score_threshold + + async def search( + self, + query: str, + keywords: Optional[List[str]] = None, + limit: int = 10, + ) -> Dict[str, Any]: + """ + 执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。 + + 对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数, + 确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。 + + Args: + query: 原始用户查询(用于向量检索和 BM25 补查) + keywords: 关键词列表(用于全文检索),为 None 时使用 [query] + limit: 最大返回数量 + + Returns: + { + "memories": [格式化后的记忆 dict, ...], + "content": "拼接的纯文本摘要", + "keyword_raw": int, + "embedding_raw": int, + } + """ + if keywords is None: + keywords = [query] if query else [] + + connector = Neo4jConnector() + try: + kw_task = self._keyword_search(connector, keywords, limit) + emb_task = 
self._embedding_search(connector, query, limit) + + kw_results, emb_results = await asyncio.gather( + kw_task, emb_task, return_exceptions=True + ) + if isinstance(kw_results, Exception): + logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}") + kw_results = [] + if isinstance(emb_results, Exception): + logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}") + emb_results = [] + + # 补查 BM25:找出 embedding 命中但 keyword 未命中的 id, + # 用原始 query 对这些节点补查全文索引拿 BM25 score + kw_ids = {r.get("id") for r in kw_results if r.get("id")} + emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids} + + if emb_only_ids and query: + backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit) + # 把补查到的 BM25 score 注入到 embedding 结果中 + backfill_map = {r["id"]: r.get("score", 0) for r in backfill} + for r in emb_results: + rid = r.get("id", "") + if rid in backfill_map: + r["bm25_backfill_score"] = backfill_map[rid] + logger.info( + f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, " + f"{len(backfill_map)} got BM25 scores" + ) + + reranked = self._rerank(kw_results, emb_results, limit) + + memories = [] + content_parts = [] + for record in reranked: + fmt = self._format_result(record) + fmt["score"] = round(record.get("content_score", 0), 4) + memories.append(fmt) + content_parts.append(self._build_content_text(fmt)) + + logger.info( + f"[PerceptualSearch] {len(memories)} results after rerank " + f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})" + ) + return { + "memories": memories, + "content": "\n\n".join(content_parts), + "keyword_raw": len(kw_results), + "embedding_raw": len(emb_results), + } + finally: + await connector.close() + + async def _bm25_backfill( + self, + connector: Neo4jConnector, + query: str, + target_ids: set, + limit: int, + ) -> List[dict]: + """ + 对指定 id 集合补查全文索引 BM25 score。 + + 用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。 + """ + 
escaped = escape_lucene_query(query) + if not escaped.strip(): + return [] + try: + r = await search_perceptual( + connector=connector, query=escaped, + end_user_id=self.end_user_id, + limit=limit * 5, # 多查一些以提高命中率 + ) + all_hits = r.get("perceptuals", []) + return [h for h in all_hits if h.get("id") in target_ids] + except Exception as e: + logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}") + return [] + + async def _keyword_search( + self, + connector: Neo4jConnector, + keywords: List[str], + limit: int, + ) -> List[dict]: + """并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。""" + seen_ids: set = set() + all_results: List[dict] = [] + + async def _one(kw: str): + escaped = escape_lucene_query(kw) + if not escaped.strip(): + return [] + r = await search_perceptual( + connector=connector, query=escaped, + end_user_id=self.end_user_id, limit=limit, + ) + return r.get("perceptuals", []) + + tasks = [_one(kw) for kw in keywords[:10]] + batch = await asyncio.gather(*tasks, return_exceptions=True) + + for result in batch: + if isinstance(result, Exception): + logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}") + continue + for rec in result: + rid = rec.get("id", "") + if rid and rid not in seen_ids: + seen_ids.add(rid) + all_results.append(rec) + + all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True) + return all_results[:limit] + + async def _embedding_search( + self, + connector: Neo4jConnector, + query_text: str, + limit: int, + ) -> List[dict]: + """向量语义检索,返回原始结果(不做阈值过滤)。""" + try: + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + from app.core.models.base import RedBearModelConfig + from app.db import get_db_context + from app.services.memory_config_service import MemoryConfigService + + with get_db_context() as db: + cfg = MemoryConfigService(db).get_embedder_config( + str(self.memory_config.embedding_model_id) + ) + client = OpenAIEmbedderClient(RedBearModelConfig(**cfg)) + + r = await 
search_perceptual_by_embedding( + connector=connector, embedder_client=client, + query_text=query_text, end_user_id=self.end_user_id, + limit=limit, + ) + return r.get("perceptuals", []) + except Exception as e: + logger.warning(f"[PerceptualSearch] embedding search failed: {e}") + return [] + + def _rerank( + self, + keyword_results: List[dict], + embedding_results: List[dict], + limit: int, + ) -> List[dict]: + """BM25 + embedding 融合排序。 + + 对 embedding 结果中带有 bm25_backfill_score 的条目, + 将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。 + """ + # 把补查的 BM25 score 合并到 keyword_results 中统一归一化 + emb_backfill_items = [] + for item in embedding_results: + backfill_score = item.get("bm25_backfill_score") + if backfill_score is not None and item.get("id"): + emb_backfill_items.append({"id": item["id"], "score": backfill_score}) + + # 合并后统一归一化 BM25 scores + all_bm25_items = keyword_results + emb_backfill_items + all_bm25_items = self._normalize_scores(all_bm25_items) + + # 建立 id -> normalized BM25 score 的映射 + bm25_norm_map: Dict[str, float] = {} + for item in all_bm25_items: + item_id = item.get("id", "") + if item_id: + bm25_norm_map[item_id] = float(item.get("normalized_score", 0)) + + # 归一化 embedding scores + embedding_results = self._normalize_scores(embedding_results) + + # 合并 + combined: Dict[str, dict] = {} + for item in keyword_results: + item_id = item.get("id", "") + if not item_id: + continue + combined[item_id] = item.copy() + combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0) + combined[item_id]["embedding_score"] = 0.0 + + for item in embedding_results: + item_id = item.get("id", "") + if not item_id: + continue + if item_id in combined: + combined[item_id]["embedding_score"] = item.get("normalized_score", 0) + else: + combined[item_id] = item.copy() + combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0) + combined[item_id]["embedding_score"] = item.get("normalized_score", 0) + + for item in combined.values(): + bm25 = float(item.get("bm25_score", 
0) or 0) + emb = float(item.get("embedding_score", 0) or 0) + item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb + + results = list(combined.values()) + before = len(results) + results = [r for r in results if r["content_score"] >= self.content_score_threshold] + results.sort(key=lambda x: x["content_score"], reverse=True) + results = results[:limit] + + logger.info( + f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} " + f"(alpha={self.alpha}, threshold={self.content_score_threshold})" + ) + return results + + @staticmethod + def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]: + """Z-score + sigmoid 归一化。""" + if not items: + return items + scores = [float(it.get(field, 0) or 0) for it in items] + if len(scores) <= 1: + for it in items: + it[f"normalized_{field}"] = 1.0 + return items + mean = sum(scores) / len(scores) + var = sum((s - mean) ** 2 for s in scores) / len(scores) + std = math.sqrt(var) + if std == 0: + for it in items: + it[f"normalized_{field}"] = 1.0 + else: + for it, s in zip(items, scores): + z = (s - mean) / std + it[f"normalized_{field}"] = 1 / (1 + math.exp(-z)) + return items + + @staticmethod + def _format_result(record: dict) -> dict: + return { + "id": record.get("id", ""), + "perceptual_type": record.get("perceptual_type", ""), + "file_name": record.get("file_name", ""), + "file_path": record.get("file_path", ""), + "summary": record.get("summary", ""), + "topic": record.get("topic", ""), + "domain": record.get("domain", ""), + "keywords": record.get("keywords", []), + "created_at": str(record.get("created_at", "")), + "file_type": record.get("file_type", ""), + "score": record.get("score", 0), + } + + @staticmethod + def _build_content_text(formatted: dict) -> str: + parts = [] + if formatted["summary"]: + parts.append(formatted["summary"]) + if formatted["topic"]: + parts.append(f"[主题: {formatted['topic']}]") + if formatted["keywords"]: + kw_list = formatted["keywords"] 
+ if isinstance(kw_list, list): + parts.append(f"[关键词: {', '.join(kw_list)}]") + if formatted["file_name"]: + parts.append(f"[文件: {formatted['file_name']}]") + return " ".join(parts) + + +def _extract_keywords_from_problems(problem_extension: dict) -> List[str]: + """Extract search keywords from problem extension results.""" + keywords = [] + context = problem_extension.get("context", {}) + if isinstance(context, dict): + for original_q, extended_qs in context.items(): + keywords.append(original_q) + if isinstance(extended_qs, list): + keywords.extend(extended_qs) + return keywords + + +async def perceptual_retrieve_node(state: ReadState) -> ReadState: + """ + LangGraph node: perceptual memory retrieval. + + Uses PerceptualSearchService to run keyword + embedding search with + BM25 fusion reranking, then writes results to state['perceptual_data']. + """ + end_user_id = state.get("end_user_id", "") + problem_extension = state.get("problem_extension", {}) + original_query = state.get("data", "") + memory_config = state.get("memory_config", None) + + logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}") + + keywords = _extract_keywords_from_problems(problem_extension) + if not keywords: + keywords = [original_query] if original_query else [] + + logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted") + + service = PerceptualSearchService( + end_user_id=end_user_id, + memory_config=memory_config, + ) + search_result = await service.search( + query=original_query, + keywords=keywords, + limit=10, + ) + + result = { + "memories": search_result["memories"], + "content": search_result["content"], + "_intermediate": { + "type": "perceptual_retrieve", + "title": "感知记忆检索", + "data": search_result["memories"], + "query": original_query, + "result_count": len(search_result["memories"]), + }, + } + return {"perceptual_data": result} diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py 
b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 3030669c..2d6eaa81 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState: logger.info(f"Problem extension result: {aggregated_dict}") # Emit intermediate output for frontend - print(time.time() - start) result = { "context": aggregated_dict, "original": data, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index d967a285..1bf68966 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -1,7 +1,11 @@ +import asyncio import os import time from app.core.logging_config import get_agent_logger, log_time +from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import ( + PerceptualSearchService, +) from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, SummaryResponse, @@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState: try: if storage_type != "rag": - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search( + + async def _perceptual_search(): + service = PerceptualSearchService( + end_user_id=end_user_id, + memory_config=memory_config, + ) + return await service.search(query=data, limit=5) + + hybrid_task = SearchService().execute_hybrid_search( **search_params, memory_config=memory_config, - expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement + expand_communities=False, ) + perceptual_task = _perceptual_search() + + gather_results = await asyncio.gather( + hybrid_task, perceptual_task, return_exceptions=True + ) + hybrid_result = gather_results[0] + perceptual_results = gather_results[1] + + # 处理 hybrid search 异常 + 
if isinstance(hybrid_result, Exception): + raise hybrid_result + retrieve_info, question, raw_results = hybrid_result + + # 处理感知记忆结果 + if isinstance(perceptual_results, Exception): + logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}") + perceptual_results = [] + + # 拼接感知记忆内容到 retrieve_info + if perceptual_results and isinstance(perceptual_results, dict): + perceptual_content = perceptual_results.get("content", "") + if perceptual_content: + retrieve_info = f"{retrieve_info}\n\n\n{perceptual_content}" + count = len(perceptual_results.get("memories", [])) + logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)") + # 调试:打印 community 检索结果数量 if raw_results and isinstance(raw_results, dict): reranked = raw_results.get('reranked_results', {}) @@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState: "error": str(e) } end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 + duration = end - start log_time('检索', duration) return {"summary": summary} @@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState: retrieve_info_str = list(set(retrieve_info_str)) retrieve_info_str = '\n'.join(retrieve_info_str) - aimessages = await summary_llm(state, history, retrieve_info_str, - 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") + # Merge perceptual memory content + perceptual_data = state.get("perceptual_data", {}) + perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else "" + if perceptual_content: + retrieve_info_str = f"{retrieve_info_str}\n\n\n{perceptual_content}" + + aimessages = await summary_llm( + state, + history, + retrieve_info_str, + 'direct_summary_prompt.jinja2', + 'retrieve_summary', RetrieveSummaryResponse, + "1" + ) if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) if aimessages == '': @@ -458,6 +505,12 @@ 
async def Summary(state: ReadState) -> ReadState: retrieve_info_str += i + '\n' history = await summary_history(state) + # Merge perceptual memory content + perceptual_data = state.get("perceptual_data", {}) + perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else "" + if perceptual_content: + retrieve_info_str = f"{retrieve_info_str}\n\n\n{perceptual_content}" + data = { "query": query, "history": history, @@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState: if key == 'answer_small': for i in value: retrieve_info_str += i + '\n' + + # Merge perceptual memory content + perceptual_data = state.get("perceptual_data", {}) + perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else "" + if perceptual_content: + retrieve_info_str = f"{retrieve_info_str}\n\n\n{perceptual_content}" + data = { "query": query, "history": history, diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index e698e6ad..d3ca4ea7 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( retrieve_nodes, ) +from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import ( + perceptual_retrieve_node, +) from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( Input_Summary, Retrieve_Summary, @@ -48,13 +51,14 @@ async def make_read_graph(): """ try: # Build workflow graph - workflow = StateGraph(ReadState) + workflow = StateGraph(ReadState) workflow.add_node("content_input", content_input_node) workflow.add_node("Split_The_Problem", Split_The_Problem) workflow.add_node("Problem_Extension", Problem_Extension) workflow.add_node("Input_Summary", Input_Summary) 
workflow.add_node("Retrieve", retrieve_nodes) # workflow.add_node("Retrieve", retrieve) + workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node) workflow.add_node("Verify", Verify) workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Summary", Summary) @@ -65,14 +69,15 @@ async def make_read_graph(): workflow.add_conditional_edges("content_input", Split_continue) workflow.add_edge("Input_Summary", END) workflow.add_edge("Split_The_Problem", "Problem_Extension") - workflow.add_edge("Problem_Extension", "Retrieve") + # After Problem_Extension, retrieve perceptual memory first, then main Retrieve + workflow.add_edge("Problem_Extension", "Perceptual_Retrieve") + workflow.add_edge("Perceptual_Retrieve", "Retrieve") workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_edge("Retrieve_Summary", END) workflow.add_conditional_edges("Verify", Verify_continue) workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) - '''-----''' # workflow.add_edge("Retrieve", END) # Compile workflow @@ -80,7 +85,5 @@ async def make_read_graph(): yield graph except Exception as e: - print(f"创建工作流失败: {e}") + logger.error(f"创建工作流失败: {e}") raise - finally: - print("工作流创建完成") diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 90b1c088..eaa5f0ab 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger from app.core.memory.src.search import run_hybrid_search from app.core.memory.utils.data.text_utils import escape_lucene_query - logger = get_agent_logger(__name__) # 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化) @@ -31,10 +30,10 @@ def _clean_expand_fields(obj): async def expand_communities_to_statements( - community_results: List[dict], - end_user_id: str, - existing_content: str = "", - limit: int = 10, + 
community_results: List[dict], + end_user_id: str, + existing_content: str = "", + limit: int = 10, ) -> Tuple[List[dict], List[str]]: """ 社区展开 helper:给定命中的 community 列表,拉取关联 Statement。 @@ -76,17 +75,18 @@ async def expand_communities_to_statements( if s.get("statement") and s["statement"] not in existing_lines ] cleaned = _clean_expand_fields(expanded_stmts) - logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}") + logger.info( + f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}") return cleaned, new_texts class SearchService: """Service for executing hybrid search and processing results.""" - + def __init__(self): """Initialize the search service.""" logger.info("SearchService initialized") - + def extract_content_from_result(self, result: dict, node_type: str = "") -> str: """ Extract only meaningful content from search results, dropping all metadata. 
@@ -107,19 +107,19 @@ class SearchService: """ if not isinstance(result, dict): return str(result) - + content_parts = [] - + # Statements: extract statement field if 'statement' in result and result['statement']: content_parts.append(result['statement']) - + # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 is_community = ( - node_type == "community" - or 'member_count' in result - or 'core_entities' in result + node_type == "community" + or 'member_count' in result + or 'core_entities' in result ) if is_community: name = result.get('name', '') @@ -130,16 +130,16 @@ class SearchService: elif 'content' in result and result['content']: # Summaries / Chunks content_parts.append(result['content']) - + # Entities: extract name and fact_summary (commented out in original) # if 'name' in result and result['name']: # content_parts.append(result['name']) # if result.get('fact_summary'): # content_parts.append(result['fact_summary']) - + # Return concatenated content or empty string return '\n'.join(content_parts) if content_parts else "" - + def clean_query(self, query: str) -> str: """ Clean and escape query text for Lucene. 
@@ -155,33 +155,33 @@ class SearchService: Cleaned and escaped query string """ q = str(query).strip() - + # Remove wrapping quotes if (q.startswith("'") and q.endswith("'")) or ( - q.startswith('"') and q.endswith('"') + q.startswith('"') and q.endswith('"') ): q = q[1:-1] - + # Remove newlines and carriage returns q = q.replace('\r', ' ').replace('\n', ' ').strip() - + # Apply Lucene escaping q = escape_lucene_query(q) - + return q - + async def execute_hybrid_search( - self, - end_user_id: str, - question: str, - limit: int = 5, - search_type: str = "hybrid", - include: Optional[List[str]] = None, - rerank_alpha: float = 0.4, - output_path: str = "search_results.json", - return_raw_results: bool = False, - memory_config = None, - expand_communities: bool = True, + self, + end_user_id: str, + question: str, + limit: int = 5, + search_type: str = "hybrid", + include: Optional[List[str]] = None, + rerank_alpha: float = 0.4, + output_path: str = "search_results.json", + return_raw_results: bool = False, + memory_config=None, + expand_communities: bool = True, ) -> Tuple[str, str, Optional[dict]]: """ Execute hybrid search and return clean content. 
@@ -205,10 +205,10 @@ class SearchService: """ if include is None: include = ["statements", "chunks", "entities", "summaries", "communities"] - + # Clean query cleaned_query = self.clean_query(question) - + try: # Execute search answer = await run_hybrid_search( @@ -221,18 +221,18 @@ class SearchService: memory_config=memory_config, rerank_alpha=rerank_alpha ) - + # Extract results based on search type and include parameter # Prioritize summaries as they contain synthesized contextual information answer_list = [] - + # For hybrid search, use reranked_results if search_type == "hybrid": reranked_results = answer.get('reranked_results', {}) - + # Priority order: summaries first (most contextual), then communities, statements, chunks, entities priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] - + for category in priority_order: if category in include and category in reranked_results: category_results = reranked_results[category] @@ -242,7 +242,7 @@ class SearchService: # For keyword or embedding search, results are directly in answer dict # Apply same priority order priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] - + for category in priority_order: if category in include and category in answer: category_results = answer[category] @@ -261,7 +261,7 @@ class SearchService: end_user_id=end_user_id, ) answer_list.extend(cleaned_stmts) - + # Extract clean content from all results,按类型传入 node_type 区分 community content_list = [] for ans in answer_list: @@ -269,19 +269,18 @@ class SearchService: ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" content_list.append(self.extract_content_from_result(ans, node_type=ntype)) - # Filter out empty strings and join with newlines clean_content = '\n'.join([c for c in content_list if c]) - + # Log first 200 chars logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...") - + # Return raw results if requested if return_raw_results: return 
clean_content, cleaned_query, answer else: return clean_content, cleaned_query, None - + except Exception as e: logger.error( f"Search failed for query '{question}' in group '{end_user_id}': {e}", diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index ea8add48..21bc1777 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -1,4 +1,3 @@ -import os from collections import defaultdict from pathlib import Path from typing import Annotated, TypedDict @@ -52,6 +51,7 @@ class ReadState(TypedDict): embedding_id: str memory_config: object # 新增字段用于传递内存配置对象 retrieve: dict + perceptual_data: dict RetrieveSummary: dict InputSummary: dict verify: dict diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 1f437973..3b0ea1ee 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -14,6 +14,7 @@ from dotenv import load_dotenv from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs +from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \ memory_summary_generation @@ -152,6 +153,24 @@ async def write( # Step 3: Save all data to Neo4j database step_start = time.time() + # Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染 + # 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并, + # 确保用户实体的 aliases 不包含 AI 助手的名字 + try: + from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( + clean_cross_role_aliases, + fetch_neo4j_assistant_aliases, + ) + neo4j_assistant_aliases = set() + if all_entity_nodes: + _eu_id = 
all_entity_nodes[0].end_user_id + if _eu_id: + neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id) + clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases) + logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}") + except Exception as e: + logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}") + # 添加死锁重试机制 max_retries = 3 retry_delay = 1 # 秒 @@ -173,15 +192,37 @@ async def write( if success: logger.info("Successfully saved all data to Neo4j") - # 使用 Celery 异步任务触发聚类(不阻塞主流程) if all_entity_nodes: + end_user_id = all_entity_nodes[0].end_user_id + + # Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体 + try: + from app.repositories.end_user_info_repository import EndUserInfoRepository + if end_user_id: + with get_db_context() as db_session: + info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id)) + pg_aliases = info.aliases if info and info.aliases else [] + if info is not None: + # 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码 + placeholder_names = list(_USER_PLACEHOLDER_NAMES) + await neo4j_connector.execute_query( + """ + MATCH (e:ExtractedEntity) + WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names + SET e.aliases = $aliases + """, + end_user_id=end_user_id, aliases=pg_aliases, + placeholder_names=placeholder_names, + ) + logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}") + except Exception as sync_err: + logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}") + + # 使用 Celery 异步任务触发聚类(不阻塞主流程) try: from app.tasks import run_incremental_clustering - end_user_id = all_entity_nodes[0].end_user_id new_entity_ids = [e.id for e in all_entity_nodes] - - # 异步提交 Celery 任务 task = run_incremental_clustering.apply_async( kwargs={ "end_user_id": end_user_id, @@ -189,7 +230,6 @@ async def write( "llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None, "embedding_model_id": 
str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, }, - # 设置任务优先级(低优先级,不影响主业务) priority=3, ) logger.info( @@ -197,7 +237,6 @@ async def write( f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}" ) except Exception as e: - # 聚类任务提交失败不影响主流程 logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True) break diff --git a/api/app/core/memory/models/__init__.py b/api/app/core/memory/models/__init__.py index 41d08908..eed8e8c4 100644 --- a/api/app/core/memory/models/__init__.py +++ b/api/app/core/memory/models/__init__.py @@ -58,6 +58,14 @@ from app.core.memory.models.triplet_models import ( TripletExtractionResponse, ) +# User metadata models +from app.core.memory.models.metadata_models import ( + UserMetadata, + UserMetadataBehavioralHints, + UserMetadataProfile, + MetadataExtractionResponse, +) + # Ontology scenario models (LLM extracted from scenarios) from app.core.memory.models.ontology_scenario_models import ( OntologyClass, @@ -124,6 +132,10 @@ __all__ = [ "Entity", "Triplet", "TripletExtractionResponse", + "UserMetadata", + "UserMetadataBehavioralHints", + "UserMetadataProfile", + "MetadataExtractionResponse", # Ontology models "OntologyClass", "OntologyExtractionResponse", diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 1b8c9d52..6e34421c 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -364,12 +364,14 @@ class ChunkNode(Node): Attributes: dialog_id: ID of the parent dialog content: The text content of the chunk + speaker: Speaker identifier ('user' or 'assistant') chunk_embedding: Optional embedding vector for the chunk sequence_number: Order of this chunk within the dialog metadata: Additional chunk metadata as key-value pairs """ dialog_id: str = Field(..., description="ID of the parent dialog") content: str = Field(..., description="The text content of the chunk") + 
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") sequence_number: int = Field(..., description="Order of this chunk within the dialog") metadata: dict = Field(default_factory=dict, description="Additional chunk metadata") diff --git a/api/app/core/memory/models/metadata_models.py b/api/app/core/memory/models/metadata_models.py new file mode 100644 index 00000000..55c2359e --- /dev/null +++ b/api/app/core/memory/models/metadata_models.py @@ -0,0 +1,57 @@ +"""Models for user metadata extraction. + +Independent from triplet_models.py - these models are used by the +standalone metadata extraction pipeline (post-dedup async Celery task). +""" + +from typing import List + +from pydantic import BaseModel, ConfigDict, Field + + +class UserMetadataProfile(BaseModel): + """用户画像信息""" + + model_config = ConfigDict(extra="ignore") + role: str = Field(default="", description="用户职业或角色") + domain: str = Field(default="", description="用户所在领域") + expertise: List[str] = Field( + default_factory=list, description="用户擅长的技能或工具" + ) + interests: List[str] = Field( + default_factory=list, description="用户关注的话题或领域标签" + ) + + +class UserMetadataBehavioralHints(BaseModel): + """行为偏好""" + + model_config = ConfigDict(extra="ignore") + learning_stage: str = Field(default="", description="学习阶段") + preferred_depth: str = Field(default="", description="偏好深度") + tone_preference: str = Field(default="", description="语气偏好") + + +class UserMetadata(BaseModel): + """用户元数据顶层结构""" + + model_config = ConfigDict(extra="ignore") + profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile) + behavioral_hints: UserMetadataBehavioralHints = Field( + default_factory=UserMetadataBehavioralHints + ) + knowledge_tags: List[str] = Field(default_factory=list, description="知识标签") + + +class MetadataExtractionResponse(BaseModel): + 
"""元数据提取 LLM 响应结构""" + + model_config = ConfigDict(extra="ignore") + user_metadata: UserMetadata = Field(default_factory=UserMetadata) + aliases_to_add: List[str] = Field( + default_factory=list, + description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)", + ) + aliases_to_remove: List[str] = Field( + default_factory=list, description="用户明确否认的别名(如'我不叫XX了')" + ) diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index e4f0d4d0..4e2883d5 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -1,4 +1,3 @@ -import argparse import asyncio import json import math @@ -6,7 +5,6 @@ import os import time from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional -from uuid import UUID if TYPE_CHECKING: from app.schemas.memory_config_schema import MemoryConfig @@ -23,7 +21,7 @@ from app.core.memory.utils.config.config_utils import ( ) from app.core.memory.utils.data.text_utils import extract_plain_query from app.core.memory.utils.data.time_utils import normalize_date_safe -from app.core.memory.utils.llm.llm_utils import get_reranker_client +# from app.core.memory.utils.llm.llm_utils import get_reranker_client from app.core.models.base import RedBearModelConfig from app.db import get_db_context from app.repositories.neo4j.graph_search import ( @@ -43,6 +41,7 @@ load_dotenv() logger = get_memory_logger(__name__) + def _parse_datetime(value: Any) -> Optional[datetime]: """Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'.""" if value is None: @@ -75,7 +74,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") if score_field == "activation_value" and score is None: scores.append(None) # 保持 None,稍后特殊处理 continue - + if score is not None and isinstance(score, (int, float)): scores.append(float(score)) else: @@ -83,10 +82,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") if not scores: return results - + # 过滤掉 None 
值,只对有效分数进行归一化 valid_scores = [s for s in scores if s is not None] - + if not valid_scores: # 所有分数都是 None,不进行归一化 for item in results: @@ -94,7 +93,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") item[f"normalized_{score_field}"] = None return results - if len(valid_scores) == 1: # Single valid score, set to 1.0 + if len(valid_scores) == 1: # Single valid score, set to 1.0 for item, score in zip(results, scores): if score_field in item or score_field == "activation_value": if score is None: @@ -132,7 +131,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") return results - def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Remove duplicate items from search results based on content. @@ -150,52 +148,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: seen_ids = set() seen_content = set() deduplicated = [] - + for item in items: # Try multiple ID fields to identify unique items item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") - + # Extract content from various possible fields content = ( - item.get("text") or - item.get("content") or - item.get("statement") or - item.get("name") or - "" + item.get("text") or + item.get("content") or + item.get("statement") or + item.get("name") or + "" ) - + # Normalize content for comparison (strip whitespace and lowercase) normalized_content = str(content).strip().lower() if content else "" - + # Check if we've seen this ID or content before is_duplicate = False - + if item_id and item_id in seen_ids: is_duplicate = True elif normalized_content and normalized_content in seen_content: # Only check content duplication if content is not empty is_duplicate = True - + if not is_duplicate: # Mark as seen if item_id: seen_ids.add(item_id) if normalized_content: # Only track non-empty content seen_content.add(normalized_content) - + deduplicated.append(item) - + return deduplicated def 
rerank_with_activation( - keyword_results: Dict[str, List[Dict[str, Any]]], - embedding_results: Dict[str, List[Dict[str, Any]]], - alpha: float = 0.6, - limit: int = 10, - forgetting_config: ForgettingEngineConfig | None = None, - activation_boost_factor: float = 0.8, - now: datetime | None = None, + keyword_results: Dict[str, List[Dict[str, Any]]], + embedding_results: Dict[str, List[Dict[str, Any]]], + alpha: float = 0.6, + limit: int = 10, + forgetting_config: ForgettingEngineConfig | None = None, + activation_boost_factor: float = 0.8, + now: datetime | None = None, + content_score_threshold: float = 0.5, ) -> Dict[str, List[Dict[str, Any]]]: """ 两阶段排序:先按内容相关性筛选,再按激活值排序。 @@ -222,6 +221,8 @@ def rerank_with_activation( forgetting_config: 遗忘引擎配置(当前未使用) activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8) now: 当前时间(用于遗忘计算) + content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score), + 低于此阈值的结果会被过滤。默认 0.5。 返回: 带评分元数据的重排序结果,按 final_score 排序 @@ -229,26 +230,26 @@ def rerank_with_activation( # 验证权重范围 if not (0 <= alpha <= 1): raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}") - + # 初始化遗忘引擎(如果需要) engine = None if forgetting_config: engine = ForgettingEngine(forgetting_config) now_dt = now or datetime.now() - + reranked: Dict[str, List[Dict[str, Any]]] = {} - + for category in ["statements", "chunks", "entities", "summaries", "communities"]: keyword_items = keyword_results.get(category, []) embedding_items = embedding_results.get(category, []) - + # 步骤 1: 归一化分数 keyword_items = normalize_scores(keyword_items, "score") embedding_items = normalize_scores(embedding_items, "score") - + # 步骤 2: 按 ID 合并结果(去重) combined_items: Dict[str, Dict[str, Any]] = {} - + # 添加关键词结果 for item in keyword_items: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") @@ -257,7 +258,7 @@ def rerank_with_activation( combined_items[item_id] = item.copy() combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) combined_items[item_id]["embedding_score"] = 0 # 默认值 - + # 
添加或更新向量嵌入结果 for item in embedding_items: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") @@ -271,18 +272,18 @@ def rerank_with_activation( combined_items[item_id] = item.copy() combined_items[item_id]["bm25_score"] = 0 # 默认值 combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - + # 步骤 3: 归一化激活度分数 # 为所有项准备激活度值列表 items_list = list(combined_items.values()) items_list = normalize_scores(items_list, "activation_value") - + # 更新 combined_items 中的归一化激活度分数 for item in items_list: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") if item_id and item_id in combined_items: combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value") - + # 步骤 4: 计算基础分数和最终分数 for item_id, item in combined_items.items(): bm25_norm = float(item.get("bm25_score", 0) or 0) @@ -290,45 +291,45 @@ def rerank_with_activation( # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义 raw_act_norm = item.get("normalized_activation_value") act_norm = float(raw_act_norm) if raw_act_norm is not None else None - + # 第一阶段:只考虑内容相关性(BM25 + Embedding) # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 content_score = alpha * bm25_norm + (1 - alpha) * emb_norm base_score = content_score # 第一阶段用内容分数 - + # 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序) item["activation_score"] = act_norm # 可能为 None item["content_score"] = content_score item["base_score"] = base_score - + # 步骤 5: 应用遗忘曲线(可选) if engine: # 计算受激活度影响的记忆强度 importance = float(item.get("importance_score", 0.5) or 0.5) - + # 获取 activation_value activation_val = item.get("activation_value") - + # 只对有激活值的节点应用遗忘曲线 if activation_val is not None and isinstance(activation_val, (int, float)): activation_val = float(activation_val) - + # 计算记忆强度:importance_score × (1 + activation_value × boost_factor) memory_strength = importance * (1 + activation_val * activation_boost_factor) - + # 计算经过的时间(天数) dt = _parse_datetime(item.get("created_at")) if dt is None: time_elapsed_days = 0.0 else: 
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) - + # 获取遗忘权重 forgetting_weight = engine.calculate_weight( time_elapsed=time_elapsed_days, memory_strength=memory_strength ) - + # 应用到基础分数 item["forgetting_weight"] = forgetting_weight item["final_score"] = base_score * forgetting_weight @@ -338,7 +339,7 @@ def rerank_with_activation( else: # 不使用遗忘曲线 item["final_score"] = base_score - + # 步骤 6: 两阶段排序和限制 # 第一阶段:按内容相关性(base_score)排序,取 Top-K first_stage_limit = limit * 3 # 可配置,取3倍候选 @@ -347,11 +348,11 @@ def rerank_with_activation( key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序 reverse=True )[:first_stage_limit] - + # 第二阶段:分离有激活值和无激活值的节点 items_with_activation = [] items_without_activation = [] - + for item in first_stage_sorted: activation_score = item.get("activation_score") # 检查是否有有效的激活值(不是 None) @@ -359,14 +360,14 @@ def rerank_with_activation( items_with_activation.append(item) else: items_without_activation.append(item) - + # 优先按激活值排序有激活值的节点 sorted_with_activation = sorted( items_with_activation, key=lambda x: float(x.get("activation_score", 0) or 0), reverse=True ) - + # 如果有激活值的节点不足 limit,用无激活值的节点补充 if len(sorted_with_activation) < limit: needed = limit - len(sorted_with_activation) @@ -374,7 +375,7 @@ def rerank_with_activation( sorted_items = sorted_with_activation + items_without_activation[:needed] else: sorted_items = sorted_with_activation[:limit] - + # 两阶段排序完成,更新 final_score 以反映实际排序依据 # Stage 1: 按 content_score 筛选候选(已完成) # Stage 2: 按 activation_score 排序(已完成) @@ -390,16 +391,29 @@ def rerank_with_activation( else: # 无激活值:使用内容相关性分数 item["final_score"] = item.get("base_score", 0) - - # 最终去重确保没有重复项 + + if content_score_threshold > 0: + before_count = len(sorted_items) + sorted_items = [ + item for item in sorted_items + if float(item.get("content_score", 0) or 0) >= content_score_threshold + ] + filtered_count = before_count - len(sorted_items) + if filtered_count > 0: + logger.info( + f"[rerank] {category}: filtered 
{filtered_count}/{before_count} " + f"items below content_score_threshold={content_score_threshold}" + ) + sorted_items = _deduplicate_results(sorted_items) - + reranked[category] = sorted_items - + return reranked -def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None): +def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], + log_file: str = None): """Log search query information using the logger. Args: @@ -412,7 +426,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None, """ # Ensure the query text is plain and clean before logging cleaned_query = extract_plain_query(query_text) - + # Log using the standard logger logger.info( f"Search query: query='{cleaned_query}', type={search_type}, " @@ -439,8 +453,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any: def apply_reranker_placeholder( - results: Dict[str, List[Dict[str, Any]]], - query_text: str, + results: Dict[str, List[Dict[str, Any]]], + query_text: str, ) -> Dict[str, List[Dict[str, Any]]]: """ Placeholder for a cross-encoder reranker. @@ -483,7 +497,7 @@ def apply_reranker_placeholder( # ) -> Dict[str, List[Dict[str, Any]]]: # """ # Apply LLM-based reranking to search results. 
- + # Args: # results: Search results organized by category # query_text: Original search query @@ -491,7 +505,7 @@ def apply_reranker_placeholder( # llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM) # top_k: Maximum number of items to rerank per category # batch_size: Number of items to process concurrently - + # Returns: # Reranked results with final_score and reranker_model fields # """ @@ -501,18 +515,18 @@ def apply_reranker_placeholder( # # except Exception as e: # # logger.debug(f"Failed to load reranker config: {e}") # # rc = {} - + # # Check if reranking is enabled # enabled = rc.get("enabled", False) # if not enabled: # logger.debug("LLM reranking is disabled in configuration") # return results - + # # Load configuration parameters with defaults # llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5) # top_k = top_k if top_k is not None else rc.get("top_k", 20) # batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5) - + # # Initialize reranker client if not provided # if reranker_client is None: # try: @@ -520,10 +534,10 @@ def apply_reranker_placeholder( # except Exception as e: # logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking") # return results - + # # Get model name for metadata # model_name = getattr(reranker_client, 'model_name', 'unknown') - + # # Process each category # reranked_results = {} # for category in ["statements", "chunks", "entities", "summaries"]: @@ -531,38 +545,38 @@ def apply_reranker_placeholder( # if not items: # reranked_results[category] = [] # continue - + # # Select top K items by combined_score for reranking # sorted_items = sorted( # items, # key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0), # reverse=True # ) - + # top_items = sorted_items[:top_k] # remaining_items = sorted_items[top_k:] - + # # Extract text content from each item # def extract_text(item: Dict[str, Any]) -> str: # """Extract text 
content from a result item.""" # # Try different text fields based on category # text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or "" # return str(text).strip() - + # # Batch items for concurrent processing # batches = [] # for i in range(0, len(top_items), batch_size): # batch = top_items[i:i + batch_size] # batches.append(batch) - + # # Process batches concurrently # async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # """Process a batch of items with LLM relevance scoring.""" # scored_batch = [] - + # for item in batch: # item_text = extract_text(item) - + # # Skip items with no text # if not item_text: # item_copy = item.copy() @@ -572,7 +586,7 @@ def apply_reranker_placeholder( # item_copy["reranker_model"] = model_name # scored_batch.append(item_copy) # continue - + # # Create relevance scoring prompt # prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0. @@ -585,15 +599,15 @@ def apply_reranker_placeholder( # - 1.0 means perfectly relevant # Relevance score:""" - + # # Send request to LLM # try: # messages = [{"role": "user", "content": prompt}] # response = await reranker_client.chat(messages) - + # # Parse LLM response to extract relevance score # response_text = str(response.content if hasattr(response, 'content') else response).strip() - + # # Try to extract a float from the response # try: # # Remove any non-numeric characters except decimal point @@ -608,11 +622,11 @@ def apply_reranker_placeholder( # except (ValueError, AttributeError) as e: # logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. 
Error: {e}") # llm_score = None - + # # Calculate final score # item_copy = item.copy() # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - + # if llm_score is not None: # final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score # item_copy["llm_relevance_score"] = llm_score @@ -620,7 +634,7 @@ def apply_reranker_placeholder( # # Use combined_score as fallback # final_score = combined_score # item_copy["llm_relevance_score"] = combined_score - + # item_copy["final_score"] = final_score # item_copy["reranker_model"] = model_name # scored_batch.append(item_copy) @@ -632,14 +646,14 @@ def apply_reranker_placeholder( # item_copy["llm_relevance_score"] = combined_score # item_copy["reranker_model"] = model_name # scored_batch.append(item_copy) - + # return scored_batch - + # # Process all batches concurrently # try: # batch_tasks = [process_batch(batch) for batch in batches] # batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) - + # # Merge batch results # scored_items = [] # for result in batch_results: @@ -647,7 +661,7 @@ def apply_reranker_placeholder( # logger.warning(f"Batch processing failed: {result}") # continue # scored_items.extend(result) - + # # Add remaining items (not in top K) with their combined_score as final_score # for item in remaining_items: # item_copy = item.copy() @@ -655,11 +669,11 @@ def apply_reranker_placeholder( # item_copy["final_score"] = combined_score # item_copy["reranker_model"] = model_name # scored_items.append(item_copy) - + # # Sort all items by final_score in descending order # scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True) # reranked_results[category] = scored_items - + # except Exception as e: # logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results") # # Return original items with combined_score as final_score @@ -668,22 +682,22 @@ def apply_reranker_placeholder( # 
item["final_score"] = combined_score # item["reranker_model"] = model_name # reranked_results[category] = items - + # return reranked_results async def run_hybrid_search( - query_text: str, - search_type: str, - end_user_id: str | None, - limit: int, - include: List[str], - output_path: str | None, - memory_config: "MemoryConfig", - rerank_alpha: float = 0.6, - activation_boost_factor: float = 0.8, - use_forgetting_rerank: bool = False, - use_llm_rerank: bool = False, + query_text: str, + search_type: str, + end_user_id: str | None, + limit: int, + include: List[str], + output_path: str | None, + memory_config: "MemoryConfig", + rerank_alpha: float = 0.6, + activation_boost_factor: float = 0.8, + use_forgetting_rerank: bool = False, + use_llm_rerank: bool = False, ): """ @@ -699,7 +713,7 @@ async def run_hybrid_search( # Clean and normalize the incoming query before use/logging query_text = extract_plain_query(query_text) - + # Validate query is not empty after cleaning if not query_text or not query_text.strip(): logger.warning("Empty query after cleaning, returning empty results") @@ -716,7 +730,7 @@ async def run_hybrid_search( "error": "Empty query" } } - + # Log the search query log_search_query(query_text, search_type, end_user_id, limit, include) @@ -732,11 +746,10 @@ async def run_hybrid_search( if search_type in ["keyword", "hybrid"]: # Keyword-based search logger.info("[PERF] Starting keyword search...") - keyword_start = time.time() keyword_task = asyncio.create_task( search_graph( connector=connector, - q=query_text, + query=query_text, end_user_id=end_user_id, limit=limit, include=include @@ -746,8 +759,7 @@ async def run_hybrid_search( if search_type in ["embedding", "hybrid"]: # Embedding-based search logger.info("[PERF] Starting embedding search...") - embedding_start = time.time() - + # 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig config_load_start = time.time() try: @@ -758,8 +770,7 @@ async def run_hybrid_search( 
model_name=embedder_config_dict["model_name"], provider=embedder_config_dict["provider"], api_key=embedder_config_dict["api_key"], - base_url=embedder_config_dict["base_url"], - type="llm" + base_url=embedder_config_dict["base_url"] ) config_load_time = time.time() - config_load_start logger.info(f"[PERF] Config loading took {config_load_time:.4f}s") @@ -769,7 +780,7 @@ async def run_hybrid_search( embedder = OpenAIEmbedderClient(model_config=rb_config) embedder_init_time = time.time() - embedder_init_start logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s") - + embedding_task = asyncio.create_task( search_graph_by_embedding( connector=connector, @@ -789,7 +800,7 @@ async def run_hybrid_search( if keyword_task: keyword_results = await keyword_task - keyword_latency = time.time() - keyword_start + keyword_latency = time.time() - search_start_time latency_metrics["keyword_search_latency"] = round(keyword_latency, 4) logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s") if search_type == "keyword": @@ -799,7 +810,7 @@ async def run_hybrid_search( if embedding_task: embedding_results = await embedding_task - embedding_latency = time.time() - embedding_start + embedding_latency = time.time() - search_start_time latency_metrics["embedding_search_latency"] = round(embedding_latency, 4) logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s") if search_type == "embedding": @@ -811,7 +822,8 @@ async def run_hybrid_search( if search_type == "hybrid": results["combined_summary"] = { "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), - "total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), + "total_embedding_results": sum( + len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "search_query": query_text, "search_timestamp": datetime.now().isoformat() } @@ -819,7 +831,7 @@ async def 
run_hybrid_search( # Apply two-stage reranking with ACTR activation calculation rerank_start = time.time() logger.info("[PERF] Using two-stage reranking with ACTR activation") - + # 加载遗忘引擎配置 config_start = time.time() try: @@ -830,7 +842,7 @@ async def run_hybrid_search( forgetting_cfg = ForgettingEngineConfig() config_time = time.time() - config_start logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s") - + # 统一使用激活度重排序(两阶段:检索 + ACTR计算) rerank_compute_start = time.time() reranked_results = rerank_with_activation( @@ -843,14 +855,14 @@ async def run_hybrid_search( ) rerank_compute_time = time.time() - rerank_compute_start logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s") - + rerank_latency = time.time() - rerank_start latency_metrics["reranking_latency"] = round(rerank_latency, 4) logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s") - + # Optional: apply reranker placeholder if enabled via config reranked_results = apply_reranker_placeholder(reranked_results, query_text) - + # Apply LLM reranking if enabled llm_rerank_applied = False # if use_llm_rerank: @@ -863,11 +875,12 @@ async def run_hybrid_search( # logger.info("LLM reranking applied successfully") # except Exception as e: # logger.warning(f"LLM reranking failed: {e}, using previous scores") - + results["reranked_results"] = reranked_results results["combined_summary"] = { "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), - "total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), + "total_embedding_results": sum( + len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()), "search_query": query_text, "search_timestamp": datetime.now().isoformat(), @@ -880,17 +893,17 @@ async def run_hybrid_search( # Calculate total latency 
total_latency = time.time() - search_start_time latency_metrics["total_latency"] = round(total_latency, 4) - + # Add latency metrics to results if "combined_summary" in results: results["combined_summary"]["latency_metrics"] = latency_metrics else: results["latency_metrics"] = latency_metrics - - logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====") + + logger.info("[PERF] ===== SEARCH PERFORMANCE SUMMARY =====") logger.info(f"[PERF] Total search completed in {total_latency:.4f}s") logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}") - logger.info(f"[PERF] =========================================") + logger.info("[PERF] =========================================") # Sanitize results: drop large/unused fields _remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs @@ -909,8 +922,10 @@ async def run_hybrid_search( # Log search completion with result count if search_type == "hybrid": result_counts = { - "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()}, - "embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()} + "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in + keyword_results.items()}, + "embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in + embedding_results.items()} } else: result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()} @@ -928,12 +943,12 @@ async def run_hybrid_search( async def search_by_temporal( - end_user_id: Optional[str] = "test", - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 1, + end_user_id: Optional[str] = "test", + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = 
None, + limit: int = 1, ): """ Temporal search across Statements. @@ -969,13 +984,13 @@ async def search_by_temporal( async def search_by_keyword_temporal( - query_text: str, - end_user_id: Optional[str] = "test", - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 1, + query_text: str, + end_user_id: Optional[str] = "test", + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = None, + limit: int = 1, ): """ Temporal keyword search across Statements. @@ -1012,9 +1027,9 @@ async def search_by_keyword_temporal( async def search_chunk_by_chunk_id( - chunk_id: str, - end_user_id: Optional[str] = "test", - limit: int = 1, + chunk_id: str, + end_user_id: Optional[str] = "test", + limit: int = 1, ): """ Search for Chunks by chunk_id. @@ -1027,4 +1042,3 @@ async def search_chunk_by_chunk_id( limit=limit ) return {"chunks": chunks} - diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index 622f6e05..715f190c 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -4,6 +4,7 @@ import asyncio import difflib # 提供字符串相似度计算工具 import importlib +import logging import os import re from datetime import datetime @@ -16,6 +17,8 @@ from app.core.memory.models.graph_models import ( ) from app.core.memory.models.variate_config import DedupConfig +logger = logging.getLogger(__name__) + # 模块级类型统一工具函数 def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None: @@ -79,51 +82,38 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: 
ExtractedEntityNode): canonical.connect_strength = next(iter(pair)) # 别名合并(去重保序,使用标准化工具) + # 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改 try: canonical_name = (getattr(canonical, "name", "") or "").strip() - incoming_name = (getattr(ent, "name", "") or "").strip() - - # 收集所有需要合并的别名 - all_aliases = [] - - # 1. 添加canonical现有的别名 - existing = getattr(canonical, "aliases", []) or [] - all_aliases.extend(existing) - - # 2. 添加incoming实体的名称(如果不同于canonical的名称) - if incoming_name and incoming_name != canonical_name: - all_aliases.append(incoming_name) - - # 3. 添加incoming实体的所有别名 - incoming = getattr(ent, "aliases", []) or [] - all_aliases.extend(incoming) - - # 4. 标准化并去重(优先使用alias_utils工具函数) - try: - from app.core.memory.utils.alias_utils import normalize_aliases - canonical.aliases = normalize_aliases(canonical_name, all_aliases) - except Exception: - # 如果导入失败,使用增强的去重逻辑 - seen_normalized = set() - unique_aliases = [] + if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES: + incoming_name = (getattr(ent, "name", "") or "").strip() - for alias in all_aliases: - if not alias: - continue - - alias_stripped = str(alias).strip() - if not alias_stripped or alias_stripped == canonical_name: - continue - - # 标准化:转小写用于去重判断 - alias_normalized = alias_stripped.lower() - - if alias_normalized not in seen_normalized: - seen_normalized.add(alias_normalized) - unique_aliases.append(alias_stripped) + # 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体 + all_aliases = list(getattr(canonical, "aliases", []) or []) + if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES: + all_aliases.append(incoming_name) + all_aliases.extend( + a for a in (getattr(ent, "aliases", []) or []) + if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES + ) - # 排序并赋值 - canonical.aliases = sorted(unique_aliases) + try: + from app.core.memory.utils.alias_utils import normalize_aliases + canonical.aliases = normalize_aliases(canonical_name, all_aliases) + except 
Exception: + seen_normalized = set() + unique_aliases = [] + for alias in all_aliases: + if not alias: + continue + alias_stripped = str(alias).strip() + if not alias_stripped or alias_stripped == canonical_name: + continue + alias_normalized = alias_stripped.lower() + if alias_normalized not in seen_normalized: + seen_normalized.add(alias_normalized) + unique_aliases.append(alias_stripped) + canonical.aliases = sorted(unique_aliases) except Exception: pass @@ -198,6 +188,161 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode): except Exception: pass +# 用户和AI助手的占位名称集合(用于名称标准化) +_USER_PLACEHOLDER_NAMES = {"用户", "我", "user", "i"} +_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"} + +# 标准化后的规范名称和类型 +_CANONICAL_USER_NAME = "用户" +_CANONICAL_USER_TYPE = "用户" +_CANONICAL_ASSISTANT_NAME = "AI助手" +_CANONICAL_ASSISTANT_TYPE = "Agent" + +# 用户和AI助手的所有可能名称(用于判断实体是否为特殊角色实体) +_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES +_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES + + +def _is_user_entity(ent: ExtractedEntityNode) -> bool: + """判断实体是否为用户实体(name 或 entity_type 匹配)""" + name = (getattr(ent, "name", "") or "").strip().lower() + etype = (getattr(ent, "entity_type", "") or "").strip() + return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE + + +def _is_assistant_entity(ent: ExtractedEntityNode) -> bool: + """判断实体是否为AI助手实体(name 或 entity_type 匹配)""" + name = (getattr(ent, "name", "") or "").strip().lower() + etype = (getattr(ent, "entity_type", "") or "").strip() + return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE + + +def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool: + """判断两个实体的合并是否会跨越用户/AI助手角色边界。 + + 用户实体和AI助手实体永远不应该被合并在一起。 + 如果一方是用户实体、另一方是AI助手实体,返回 True(阻止合并)。 + """ + return ( + (_is_user_entity(a) and _is_assistant_entity(b)) + or (_is_assistant_entity(a) and _is_user_entity(b)) + ) + + +def _normalize_special_entity_names( + 
entity_nodes: List[ExtractedEntityNode], +) -> None: + """标准化用户和AI助手实体的名称和类型。 + + 多轮对话中,LLM 对同一角色可能使用不同的名称变体(如"用户"/"我"/"User", + "AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。 + 此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type,确保: + - name="用户" 的实体 entity_type 一定为 "用户" + - name="AI助手" 的实体 entity_type 一定为 "Agent" + + Args: + entity_nodes: 实体节点列表(原地修改) + """ + for ent in entity_nodes: + name = (getattr(ent, "name", "") or "").strip() + name_lower = name.lower() + + if name_lower in _USER_PLACEHOLDER_NAMES: + ent.name = _CANONICAL_USER_NAME + ent.entity_type = _CANONICAL_USER_TYPE + elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES: + ent.name = _CANONICAL_ASSISTANT_NAME + ent.entity_type = _CANONICAL_ASSISTANT_TYPE + + # 第二步:清洗用户/AI助手之间的别名交叉污染(复用 clean_cross_role_aliases) + clean_cross_role_aliases(entity_nodes) + + +async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set: + """从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。 + + 这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用, + 避免多处维护相同的 Cypher 和名称列表。 + + Args: + neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法) + end_user_id: 终端用户 ID + + Returns: + 小写归一化后的助手别名集合 + """ + # 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致) + query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES] + # 去重保序 + query_names = list(dict.fromkeys(query_names)) + + cypher = """ + MATCH (e:ExtractedEntity) + WHERE e.end_user_id = $end_user_id AND e.name IN $names + RETURN e.aliases AS aliases + """ + try: + result = await neo4j_connector.execute_query( + cypher, end_user_id=end_user_id, names=query_names + ) + assistant_aliases: set = set() + for record in (result or []): + for alias in (record.get("aliases") or []): + assistant_aliases.add(alias.strip().lower()) + if assistant_aliases: + logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}") + return assistant_aliases + except Exception as e: + logger.warning(f"查询 Neo4j AI 助手别名失败: {e}") + return set() + + +def clean_cross_role_aliases( + entity_nodes: 
List[ExtractedEntityNode], + external_assistant_aliases: set = None, +) -> None: + """清洗用户实体和AI助手实体之间的别名交叉污染。 + + 在 Neo4j 写入前调用,确保: + - 用户实体的 aliases 不包含 AI 助手的别名 + - AI 助手实体的 aliases 不包含用户的别名 + + Args: + entity_nodes: 实体节点列表(原地修改) + external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询), + 与本轮实体中的 AI 助手别名合并使用 + """ + # 收集本轮 AI 助手实体的所有别名 + assistant_aliases = set(external_assistant_aliases or set()) + user_aliases = set() + + for ent in entity_nodes: + if _is_assistant_entity(ent): + for alias in (getattr(ent, "aliases", []) or []): + assistant_aliases.add(alias.strip().lower()) + elif _is_user_entity(ent): + for alias in (getattr(ent, "aliases", []) or []): + user_aliases.add(alias.strip().lower()) + + # 从用户实体的 aliases 中移除 AI 助手别名 + if assistant_aliases: + for ent in entity_nodes: + if _is_user_entity(ent): + original = getattr(ent, "aliases", []) or [] + cleaned = [a for a in original if a.strip().lower() not in assistant_aliases] + if len(cleaned) < len(original): + ent.aliases = cleaned + + # 从 AI 助手实体的 aliases 中移除用户别名 + if user_aliases: + for ent in entity_nodes: + if _is_assistant_entity(ent): + original = getattr(ent, "aliases", []) or [] + cleaned = [a for a in original if a.strip().lower() not in user_aliases] + if len(cleaned) < len(original): + ent.aliases = cleaned + + def accurate_match( entity_nodes: List[ExtractedEntityNode] ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: @@ -261,6 +406,10 @@ def accurate_match( canonical = alias_index.get((ent_uid, ent_name)) # 确保不是自身 if canonical is not None and canonical.id != ent.id: + # 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并) + if _would_merge_cross_role(canonical, ent): + i += 1 + continue _merge_attribute(canonical, ent) id_redirect[ent.id] = canonical.id for k, v in list(id_redirect.items()): @@ -571,66 +720,37 @@ def fuzzy_match( def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode): - """ 模糊匹配中的实体合并。 + """模糊匹配中的实体合并(别名部分)。 - 合并策略: - 1. 
保留canonical的主名称不变 - 2. 将losing的主名称添加为alias(如果不同) - 3. 合并两个实体的所有aliases - 4. 自动去重(case-insensitive)并排序 - - Args: - canonical: 规范实体(保留) - losing: 被合并实体(删除) - - Note: - 使用alias_utils.normalize_aliases进行标准化去重 + 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。 """ - # 获取规范实体的名称 canonical_name = (getattr(canonical, "name", "") or "").strip() + if canonical_name.lower() in _USER_PLACEHOLDER_NAMES: + return + losing_name = (getattr(losing, "name", "") or "").strip() - # 收集所有需要合并的别名 - all_aliases = [] - - # 1. 添加canonical现有的别名 - current_aliases = getattr(canonical, "aliases", []) or [] - all_aliases.extend(current_aliases) - - # 2. 添加losing实体的名称(如果不同于canonical的名称) + all_aliases = list(getattr(canonical, "aliases", []) or []) if losing_name and losing_name != canonical_name: all_aliases.append(losing_name) + all_aliases.extend(getattr(losing, "aliases", []) or []) - # 3. 添加losing实体的所有别名 - losing_aliases = getattr(losing, "aliases", []) or [] - all_aliases.extend(losing_aliases) - - # 4. 标准化并去重(使用标准化后的字符串进行去重) try: from app.core.memory.utils.alias_utils import normalize_aliases canonical.aliases = normalize_aliases(canonical_name, all_aliases) except Exception: - # 如果导入失败,使用增强的去重逻辑 - # 使用标准化后的字符串作为key进行去重 seen_normalized = set() unique_aliases = [] - for alias in all_aliases: if not alias: continue - alias_stripped = str(alias).strip() if not alias_stripped or alias_stripped == canonical_name: continue - - # 标准化:转小写用于去重判断 alias_normalized = alias_stripped.lower() - if alias_normalized not in seen_normalized: seen_normalized.add(alias_normalized) unique_aliases.append(alias_stripped) - - # 排序并赋值 canonical.aliases = sorted(unique_aliases) # ========== 主循环:遍历所有实体对进行模糊匹配 ========== @@ -704,6 +824,11 @@ def fuzzy_match( # 条件A(快速通道):alias_match_merge = True # 条件B(标准通道):s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover): + # 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并) + if _would_merge_cross_role(a, 
b): + j += 1 + continue + # ========== 第六步:执行实体合并 ========== # 6.1 合并别名 @@ -813,6 +938,12 @@ async def LLM_decision( # 决策中包含去重和消歧的功能 b = entity_by_id.get(losing_id) if not a or not b: # 若不存在 a 或 b,可能已在精确或模糊阶段合并,在之前阶段合并之后,不会再处理但是处于审计的目的会记录 continue + # 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并) + if _would_merge_cross_role(a, b): + llm_records.append( + f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})" + ) + continue _merge_attribute(a, b) # ID 重定向 try: @@ -934,6 +1065,9 @@ async def deduplicate_entities_and_edges( 返回:去重后的实体、语句→实体边、实体↔实体边。 """ local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化,保留为了之后对于LLM决策追溯 + # 0) 标准化用户和AI助手实体名称(确保多轮对话中的变体名称统一) + _normalize_special_entity_names(entity_nodes) + # 1) 精确匹配 deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes) diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index 4b9c5718..13534b3d 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -15,6 +15,7 @@ from app.core.memory.models.message_models import DialogData from app.core.memory.models.variate_config import ExtractionPipelineConfig from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( deduplicate_entities_and_edges, + clean_cross_role_aliases, ) from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import ( second_layer_dedup_and_merge_with_neo4j, @@ -100,6 +101,10 @@ async def dedup_layers_and_merge_and_return( except Exception as e: print(f"Second-layer dedup failed: {e}") + # 第二层去重后,清洗用户/AI助手之间的别名交叉污染 + # 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据 + clean_cross_role_aliases(fused_entity_nodes) + return ( dialogue_nodes, chunk_nodes, diff --git 
a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index b20112a2..75fc87d2 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -44,6 +44,10 @@ from app.core.memory.models.variate_config import ( from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( dedup_layers_and_merge_and_return, ) +from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( + _USER_PLACEHOLDER_NAMES, + fetch_neo4j_assistant_aliases, +) from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( embedding_generation, generate_entity_embeddings_from_triplets, @@ -307,10 +311,53 @@ class ExtractionOrchestrator: dialog_data_list, ) - # 步骤 7: 同步用户别名到数据库表(仅正式模式) + # 步骤 7: 触发异步元数据和别名提取(仅正式模式) if not is_pilot_run: - logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表") - await self._update_end_user_other_name(entity_nodes, dialog_data_list) + try: + from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import ( + MetadataExtractor, + ) + + metadata_extractor = MetadataExtractor( + llm_client=self.llm_client, language=self.language + ) + user_statements = ( + metadata_extractor.collect_user_related_statements( + entity_nodes, statement_nodes, statement_entity_edges + ) + ) + if user_statements: + end_user_id = ( + dialog_data_list[0].end_user_id + if dialog_data_list + else None + ) + config_id = ( + dialog_data_list[0].config_id + if dialog_data_list + and hasattr(dialog_data_list[0], "config_id") + else None + ) + if end_user_id: + from app.tasks import extract_user_metadata_task + + extract_user_metadata_task.delay( + end_user_id=str(end_user_id), + statements=user_statements, + 
config_id=str(config_id) if config_id else None, + language=self.language, + ) + logger.info( + f"已触发异步元数据提取任务,共 {len(user_statements)} 条用户相关 statement" + ) + else: + logger.info("未找到用户相关 statement,跳过元数据提取") + except Exception as e: + logger.error( + f"触发元数据提取任务失败(不影响主流程): {e}", exc_info=True + ) + + # 别名同步已迁移到 Celery 元数据提取任务中,不再在此处执行 logger.info(f"知识提取流水线运行完成({mode_str})") return ( @@ -1103,6 +1150,7 @@ class ExtractionOrchestrator: end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id content=chunk.content, + speaker=getattr(chunk, 'speaker', None), chunk_embedding=chunk.chunk_embedding, sequence_number=chunk_idx, # 添加必需的 sequence_number 字段 created_at=dialog_data.created_at, @@ -1338,17 +1386,23 @@ class ExtractionOrchestrator: async def _update_end_user_other_name( self, entity_nodes: List[ExtractedEntityNode], - dialog_data_list: List[DialogData] + dialog_data_list: List[DialogData], ) -> None: """ - 从 Neo4j 读取用户实体的最终 aliases,同步到 end_user 和 end_user_info 表 + 将本轮提取的用户别名同步到 end_user 和 end_user_info 表。 - 注意: - 1. other_name 使用本次对话提取的第一个别名(保持时间顺序) - 2. aliases 从 Neo4j 读取(保持完整性) + PgSQL end_user_info.aliases 是用户别名的唯一权威源。 + 此方法仅将本轮 LLM 从对话中新提取的别名增量追加到 PgSQL, + 不再从 Neo4j 二层去重合并历史别名,避免脏数据反向污染 PgSQL。 + + 策略: + 1. 从本轮对话原始发言中提取用户别名(current_aliases) + 2. 从 PgSQL end_user_info 读取已有的 aliases(db_aliases) + 3. 合并 db_aliases + current_aliases,去重保序 + 4. 写回 PgSQL Args: - entity_nodes: 实体节点列表 + entity_nodes: 去重后的实体节点列表(内存中) dialog_data_list: 对话数据列表 """ try: @@ -1361,23 +1415,28 @@ class ExtractionOrchestrator: logger.warning("end_user_id 为空,跳过用户别名同步") return - # 1. 提取本次对话的用户别名(保持 LLM 提取的原始顺序,不排序) - current_aliases = self._extract_current_aliases(entity_nodes) + # 1. 提取本轮对话的用户别名(保持 LLM 提取的原始顺序,不排序) + current_aliases = self._extract_current_aliases(entity_nodes, dialog_data_list) - # 2. 
从 Neo4j 获取完整 aliases(权威数据源) - neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id) + # 1.6 从 Neo4j 查询已有的 AI 助手别名,作为额外的排除源 + # (防止 LLM 未提取出 AI 助手实体时,AI 别名泄漏到用户别名中) + neo4j_assistant_aliases = await self._fetch_neo4j_assistant_aliases(end_user_id) + if neo4j_assistant_aliases: + before_count = len(current_aliases) + current_aliases = [ + a for a in current_aliases + if a.strip().lower() not in neo4j_assistant_aliases + ] + if len(current_aliases) < before_count: + logger.info(f"通过 Neo4j AI 助手别名排除了 {before_count - len(current_aliases)} 个误归属别名") - if not neo4j_aliases: - # Neo4j 中没有别名,使用本次对话提取的别名 - neo4j_aliases = current_aliases - if not neo4j_aliases: - logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}") - return + if not current_aliases: + logger.debug(f"本轮未提取到用户别名,跳过同步: end_user_id={end_user_id}") + return - logger.info(f"本次对话提取的 aliases: {current_aliases}") - logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}") + logger.info(f"本轮对话提取的 aliases: {current_aliases}") - # 3. 同步到数据库 + # 2. 同步到数据库 end_user_uuid = uuid.UUID(end_user_id) with get_db_context() as db: # 更新 end_user 表 @@ -1386,7 +1445,32 @@ class ExtractionOrchestrator: logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录") return - new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases) + # 3. 
从 PgSQL 读取已有 aliases 并与本轮新增合并 + info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + db_aliases = (info.aliases if info and info.aliases else []) + # 过滤掉占位名称 + db_aliases = [a for a in db_aliases if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES] + + # 合并:PgSQL 已有 + 本轮新增,去重保序(不再合并 Neo4j 历史别名) + merged_aliases = list(db_aliases) + seen_lower = {a.strip().lower() for a in merged_aliases} + for alias in current_aliases: + if alias.strip().lower() not in seen_lower: + merged_aliases.append(alias) + seen_lower.add(alias.strip().lower()) + + # 最终过滤:从合并结果中排除 AI 助手别名(清理历史脏数据) + if neo4j_assistant_aliases: + merged_aliases = [ + a for a in merged_aliases + if a.strip().lower() not in neo4j_assistant_aliases + ] + + logger.info(f"PgSQL 已有 aliases: {db_aliases}") + logger.info(f"合并后 aliases: {merged_aliases}") + + # 更新 end_user 表 other_name + new_name = self._resolve_other_name(end_user.other_name, current_aliases, merged_aliases) if new_name is not None: end_user.other_name = new_name logger.info(f"更新 end_user 表 other_name → {new_name}") @@ -1394,78 +1478,105 @@ class ExtractionOrchestrator: logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}") # 更新或创建 end_user_info 记录 - info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) if info: - new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases) + new_name_info = self._resolve_other_name(info.other_name, current_aliases, merged_aliases) if new_name_info is not None: info.other_name = new_name_info logger.info(f"更新 end_user_info 表 other_name → {new_name_info}") - if info.aliases != neo4j_aliases: - info.aliases = neo4j_aliases - logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}") + if info.aliases != merged_aliases: + info.aliases = merged_aliases + logger.info(f"同步合并后 aliases 到 end_user_info: {merged_aliases}") else: first_alias = current_aliases[0].strip() if current_aliases else "" # 确保 first_alias 不是占位名称 - if first_alias and 
first_alias not in self.USER_PLACEHOLDER_NAMES: + if first_alias and first_alias.lower() not in self.USER_PLACEHOLDER_NAMES: db.add(EndUserInfo( end_user_id=end_user_uuid, other_name=first_alias, - aliases=neo4j_aliases, - meta_data={} + aliases=merged_aliases, )) - logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={neo4j_aliases}") + logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={merged_aliases}") db.commit() except Exception as e: logger.error(f"更新 end_user other_name 失败: {e}", exc_info=True) - - - # 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中 - USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'} + # 复用 deduped_and_disamb 模块级常量,避免重复维护 + USER_PLACEHOLDER_NAMES = _USER_PLACEHOLDER_NAMES - def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: - """从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序) + def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode], dialog_data_list=None) -> List[str]: + """从用户发言的原始实体中提取本轮新增别名(绕过去重污染) - 这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。 - 第一个别名将被用作 other_name。 + 策略: + 仅从 dialog_data_list 中找到 speaker="user" 的 statement, + 从这些 statement 的 triplet_extraction_info 中提取用户实体的 aliases。 + 这样拿到的是 LLM 对用户原话的提取结果,不受去重合并的影响。 + + 注意:不再使用去重后 entity_nodes 作为兜底,因为二层去重会将 Neo4j 历史别名 + 合并进来,导致历史别名被误认为"本轮提取"。历史别名的同步由 + _extract_deduped_entity_aliases 负责。 Args: - entity_nodes: 实体节点列表 + entity_nodes: 去重后的实体节点列表(未使用,保留参数兼容性) + dialog_data_list: 对话数据列表 Returns: - 别名列表(保持 LLM 提取的原始顺序,已过滤占位名称) + 别名列表(保持原始顺序,已过滤) + """ + if not dialog_data_list: + return [] + + all_user_aliases = [] + seen_lower = set() + for dialog in dialog_data_list: + for chunk in dialog.chunks: + speaker = getattr(chunk, 'speaker', None) + for statement in chunk.statements: + stmt_speaker = getattr(statement, 'speaker', None) or speaker + if stmt_speaker != "user": + continue + triplet_info = getattr(statement, 'triplet_extraction_info', None) + if not triplet_info: + continue + for entity in 
(triplet_info.entities or []): + ent_name = getattr(entity, 'name', '').strip() + if ent_name.lower() in self.USER_PLACEHOLDER_NAMES: + for alias in (getattr(entity, 'aliases', []) or []): + a = alias.strip() + if a and a.lower() not in self.USER_PLACEHOLDER_NAMES and a.lower() not in seen_lower: + all_user_aliases.append(a) + seen_lower.add(a.lower()) + if all_user_aliases: + logger.debug(f"从用户原始发言提取到别名: {all_user_aliases}") + return all_user_aliases + + def _extract_deduped_entity_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: + """从去重后的用户实体中提取完整别名列表。 + + 二层去重会将 Neo4j 中已有的历史别名合并到 entity_nodes 的用户实体中, + 因此这里提取到的别名包含了历史积累的所有别名,可用于同步到 PgSQL。 + + Args: + entity_nodes: 去重后的实体节点列表(含二层去重合并结果) + + Returns: + 别名列表(已过滤占位名称,去重保序) """ for entity in entity_nodes: - if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES: + if getattr(entity, 'name', '').strip().lower() in self.USER_PLACEHOLDER_NAMES: aliases = getattr(entity, 'aliases', []) or [] - # 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name - filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] - logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}") - return filtered + filtered = [ + a for a in aliases + if a.strip().lower() not in self.USER_PLACEHOLDER_NAMES + ] + if filtered: + return filtered return [] - - async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]: - """从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)""" - cypher = """ - MATCH (e:ExtractedEntity) - WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I'] - RETURN e.aliases AS aliases - LIMIT 1 - """ - result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id) - if not result: - logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}") - return [] - aliases = result[0].get('aliases') or [] - if not aliases: - logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}") - return [] - # 过滤掉占位名称,防止历史脏数据传播 - filtered = [a for a in aliases if 
a.strip() not in self.USER_PLACEHOLDER_NAMES] - return filtered + async def _fetch_neo4j_assistant_aliases(self, end_user_id: str) -> set: + """从 Neo4j 查询 AI 助手实体的所有别名(用于从用户别名中排除)""" + return await fetch_neo4j_assistant_aliases(self.connector, end_user_id) def _resolve_other_name( self, @@ -1484,19 +1595,18 @@ class ExtractionOrchestrator: 注意:返回值不允许是占位名称("用户"、"我"、"User"、"I") """ # 当前值为空或为占位名称时,需要更新 - if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES: + if not current or not current.strip() or current.strip().lower() in self.USER_PLACEHOLDER_NAMES: candidate = current_aliases[0].strip() if current_aliases else None # 确保候选值不是占位名称 - if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES: return None return candidate if current not in neo4j_aliases: candidate = neo4j_aliases[0].strip() if neo4j_aliases else None # 确保候选值不是占位名称 - if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + if candidate and candidate.lower() in self.USER_PLACEHOLDER_NAMES: return None return candidate - return None async def _run_dedup_and_write_summary( diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py new file mode 100644 index 00000000..19f1e533 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/metadata_extractor.py @@ -0,0 +1,175 @@ +""" +Metadata extractor module. + +Collects user-related statements from post-dedup graph data and +extracts user metadata via an independent LLM call. 
+""" + +import logging +from typing import List, Optional + +from app.core.memory.models.graph_models import ( + ExtractedEntityNode, + StatementEntityEdge, + StatementNode, +) + +logger = logging.getLogger(__name__) + +# Reuse the same user-entity detection logic from dedup module +_USER_NAMES = {"用户", "我", "user", "i"} +_CANONICAL_USER_TYPE = "用户" + + +def _is_user_entity(ent: ExtractedEntityNode) -> bool: + """判断实体是否为用户实体""" + name = (getattr(ent, "name", "") or "").strip().lower() + etype = (getattr(ent, "entity_type", "") or "").strip() + return name in _USER_NAMES or etype == _CANONICAL_USER_TYPE + + +class MetadataExtractor: + """Extracts user metadata from post-dedup graph data via independent LLM call.""" + + def __init__(self, llm_client, language: Optional[str] = None): + self.llm_client = llm_client + self.language = language + + @staticmethod + def detect_language(statements: List[str]) -> str: + """根据 statement 文本内容检测语言。 + 如果文本中包含中文字符则返回 "zh",否则返回 "en"。 + """ + import re + + combined = " ".join(statements) + if re.search(r"[\u4e00-\u9fff]", combined): + return "zh" + return "en" + + def collect_user_related_statements( + self, + entity_nodes: List[ExtractedEntityNode], + statement_nodes: List[StatementNode], + statement_entity_edges: List[StatementEntityEdge], + ) -> List[str]: + """ + 从去重后的数据中筛选与用户直接相关且由用户发言的 statement 文本。 + + 筛选逻辑: + 1. 用户实体 → StatementEntityEdge → statement(直接关联) + 2. 
只保留 speaker="user" 的 statement(过滤 assistant 回复的噪声) + + Returns: + 用户发言的 statement 文本列表 + """ + # Find user entity IDs + user_entity_ids = set() + for ent in entity_nodes: + if _is_user_entity(ent): + user_entity_ids.add(ent.id) + + if not user_entity_ids: + logger.debug("未找到用户实体节点,跳过 statement 收集") + return [] + + # 用户实体 → StatementEntityEdge → statement + target_stmt_ids = set() + for edge in statement_entity_edges: + if edge.target in user_entity_ids: + target_stmt_ids.add(edge.source) + + # Collect: only speaker="user" statements, preserving order + result = [] + seen = set() + total_associated = 0 + skipped_non_user = 0 + for stmt_node in statement_nodes: + if stmt_node.id in target_stmt_ids and stmt_node.id not in seen: + total_associated += 1 + speaker = getattr(stmt_node, "speaker", None) or "unknown" + if speaker == "user": + text = (stmt_node.statement or "").strip() + if text: + result.append(text) + else: + skipped_non_user += 1 + seen.add(stmt_node.id) + + logger.info( + f"收集到 {len(result)} 条用户发言 statement " + f"(直接关联: {total_associated}, speaker=user: {len(result)}, " + f"跳过非user: {skipped_non_user})" + ) + if result: + for i, text in enumerate(result): + logger.info(f" [user statement {i + 1}] {text}") + if total_associated > 0 and len(result) == 0: + logger.warning( + f"有 {total_associated} 条直接关联 statement 但全部被 speaker 过滤," + f"可能本次写入不包含 user 消息" + ) + return result + + async def extract_metadata( + self, + statements: List[str], + existing_metadata: Optional[dict] = None, + existing_aliases: Optional[List[str]] = None, + ) -> Optional[tuple]: + """ + 对筛选后的 statement 列表调用 LLM 提取元数据和用户别名。 + + Args: + statements: 用户发言的 statement 文本列表 + existing_metadata: 数据库已有的元数据(可选) + existing_aliases: 数据库已有的用户别名列表(可选) + + Returns: + (UserMetadata, List[str], List[str]) tuple: (metadata, aliases_to_add, aliases_to_remove) on success, None on failure + """ + if not statements: + return None + + try: + from app.core.memory.utils.prompt.prompt_utils import prompt_env + 
+ if self.language: + detected_language = self.language + logger.info(f"元数据提取使用显式指定语言: {detected_language}") + else: + detected_language = self.detect_language(statements) + logger.info(f"元数据提取语言自动检测结果: {detected_language}") + + template = prompt_env.get_template("extract_user_metadata.jinja2") + prompt = template.render( + statements=statements, + language=detected_language, + existing_metadata=existing_metadata, + existing_aliases=existing_aliases, + json_schema="", + ) + + from app.core.memory.models.metadata_models import ( + MetadataExtractionResponse, + ) + + response = await self.llm_client.response_structured( + messages=[{"role": "user", "content": prompt}], + response_model=MetadataExtractionResponse, + ) + + if response: + metadata = response.user_metadata if response.user_metadata else None + to_add = response.aliases_to_add if response.aliases_to_add else [] + to_remove = ( + response.aliases_to_remove if response.aliases_to_remove else [] + ) + return metadata, to_add, to_remove + + logger.warning("LLM 返回的响应为空") + return None + + except Exception as e: + logger.error(f"元数据提取 LLM 调用失败: {e}", exc_info=True) + return None diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index b06bd70f..d90a49ba 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -1,6 +1,5 @@ import asyncio import logging -import os from datetime import datetime from typing import Any, Dict, List, Optional @@ -82,6 +81,7 @@ class StatementExtractor: logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") return None + async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) 
-> List[Statement]: """Process a single chunk and return extracted statements @@ -94,7 +94,8 @@ class StatementExtractor: List of ExtractedStatement objects extracted from the chunk """ chunk_content = chunk.content - + chunk_speaker = self._get_speaker_from_chunk(chunk) + if not chunk_content or len(chunk_content.strip()) < 5: logger.warning(f"Chunk {chunk.id} content too short or empty, skipping") return [] @@ -149,8 +150,6 @@ class StatementExtractor: relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT except (KeyError, ValueError): relevence_info = RelevenceInfo.RELEVANT - - chunk_speaker = self._get_speaker_from_chunk(chunk) chunk_statement = Statement( statement=extracted_stmt.statement, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index 147ed777..ea355ca1 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -1,4 +1,3 @@ -import os import asyncio from typing import List, Dict, Optional @@ -61,6 +60,7 @@ class TripletExtractor: predicate_instructions=PREDICATE_DEFINITIONS, language=self._get_language(), ontology_types=self.ontology_types, + speaker=getattr(statement, 'speaker', None), ) # Create messages for LLM diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index a71c0957..e5254646 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -42,22 +42,21 @@ class AccessHistoryManager: - access_count: 访问次数 特性: - - 
原子性更新:使用Neo4j事务确保所有字段同时更新或回滚 - - 并发安全:使用乐观锁机制防止并发冲突 + - 原子性更新:使用 APOC 原子操作确保并发安全 + - 批次内合并:同一批次中对同一节点的多次访问合并为一次更新 - 一致性保证:提供一致性检查和自动修复功能 - 智能修剪:自动修剪过长的访问历史 Attributes: connector: Neo4j连接器实例 actr_calculator: ACT-R激活值计算器实例 - max_retries: 并发冲突时的最大重试次数 """ def __init__( self, connector: Neo4jConnector, actr_calculator: ACTRCalculator, - max_retries: int = 3 + max_retries: int = 5 ): """ 初始化访问历史管理器 @@ -65,47 +64,35 @@ class AccessHistoryManager: Args: connector: Neo4j连接器实例 actr_calculator: ACT-R激活值计算器实例 - max_retries: 并发冲突时的最大重试次数(默认3次) + max_retries: 已废弃,保留参数兼容性(APOC 原子操作无需重试) """ self.connector = connector self.actr_calculator = actr_calculator - self.max_retries = max_retries - + async def record_access( self, node_id: str, node_label: str, end_user_id: Optional[str] = None, - current_time: Optional[datetime] = None + current_time: Optional[datetime] = None, + access_times: int = 1 ) -> Dict[str, Any]: """ 记录节点访问并原子性更新所有相关字段 - 这是核心方法,实现了: - 1. 首次访问:初始化access_history,计算初始激活值 - 2. 后续访问:追加访问历史,重新计算激活值 - 3. 历史修剪:当历史过长时自动修剪 - 4. 原子性:所有字段在单个事务中更新 - 5. 并发安全:使用乐观锁重试机制 - Args: node_id: 节点ID node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) end_user_id: 组ID(可选,用于过滤) current_time: 当前时间(可选,默认使用系统时间) + access_times: 本次访问次数(默认1,批量合并时可能大于1) Returns: - Dict[str, Any]: 更新后的节点数据,包含: - - id: 节点ID - - activation_value: 更新后的激活值 - - access_history: 更新后的访问历史 - - last_access_time: 最后访问时间 - - access_count: 访问次数 - - importance_score: 重要性分数 + Dict[str, Any]: 更新后的节点数据 Raises: ValueError: 如果节点不存在或节点标签无效 - RuntimeError: 如果重试次数耗尽仍然失败 + RuntimeError: 如果更新失败 """ if current_time is None: current_time = datetime.now() @@ -119,55 +106,48 @@ class AccessHistoryManager: f"Invalid node_label: {node_label}. 
Must be one of {valid_labels}" ) - # 使用乐观锁重试机制处理并发冲突 - for attempt in range(self.max_retries): - try: - # 步骤1:读取当前节点状态 - node_data = await self._fetch_node(node_id, node_label, end_user_id) - - if not node_data: - raise ValueError( - f"Node not found: {node_label} with id={node_id}" - ) - - # 步骤2:计算新的访问历史和激活值 - update_data = await self._calculate_update( - node_data=node_data, - current_time=current_time, - current_time_iso=current_time_iso + try: + # 步骤1:读取当前节点状态 + node_data = await self._fetch_node(node_id, node_label, end_user_id) + + if not node_data: + raise ValueError( + f"Node not found: {node_label} with id={node_id}" ) - - # 步骤3:原子性更新节点(使用事务) - updated_node = await self._atomic_update( - node_id=node_id, - node_label=node_label, - update_data=update_data, - end_user_id=end_user_id - ) - - logger.info( - f"成功记录访问: {node_label}[{node_id}], " - f"activation={update_data['activation_value']:.4f}, " - f"access_count={update_data['access_count']}" - ) - - return updated_node - - except Exception as e: - if attempt < self.max_retries - 1: - logger.warning( - f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}" - ) - continue - else: - logger.error( - f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], " - f"错误: {str(e)}" - ) - raise RuntimeError( - f"Failed to record access after {self.max_retries} attempts: {str(e)}" - ) - + + # 步骤2:计算新的访问历史和激活值 + update_data = await self._calculate_update( + node_data=node_data, + current_time=current_time, + current_time_iso=current_time_iso, + access_times=access_times + ) + + # 步骤3:使用 APOC 原子操作更新节点(无需重试) + updated_node = await self._atomic_update( + node_id=node_id, + node_label=node_label, + update_data=update_data, + end_user_id=end_user_id + ) + + logger.info( + f"成功记录访问: {node_label}[{node_id}], " + f"activation={update_data['activation_value']:.4f}, " + f"access_count={update_data['access_count']}" + f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" + ) + + return updated_node + + except Exception as e: + 
logger.error( + f"访问记录失败: {node_label}[{node_id}], 错误: {str(e)}" + ) + raise RuntimeError( + f"Failed to record access: {str(e)}" + ) from e + async def record_batch_access( self, node_ids: List[str], @@ -178,11 +158,10 @@ class AccessHistoryManager: """ 批量记录多个节点的访问 - 为提高性能,批量更新多个节点的访问历史。 - 每个节点独立更新,失败的节点不影响其他节点。 + 对同一个节点的多次访问会先在内存中合并,只发起一次更新。 Args: - node_ids: 节点ID列表 + node_ids: 节点ID列表(可包含重复ID) node_label: 节点标签(所有节点必须是同一类型) end_user_id: 组ID(可选) current_time: 当前时间(可选) @@ -196,25 +175,38 @@ class AccessHistoryManager: if current_time is None: current_time = datetime.now() - # PERFORMANCE FIX: Process all nodes in parallel instead of sequentially - tasks = [] + # 合并同一节点的访问次数,避免对同一节点并发写入 + access_count_map: Dict[str, int] = {} for node_id in node_ids: + access_count_map[node_id] = access_count_map.get(node_id, 0) + 1 + + merged_count = len(node_ids) - len(access_count_map) + if merged_count > 0: + logger.info( + f"批量访问合并: 原始={len(node_ids)}, " + f"去重后={len(access_count_map)}, 合并={merged_count}" + ) + + # 对去重后的节点并行发起更新 + tasks = [] + for node_id, access_times in access_count_map.items(): task = self.record_access( node_id=node_id, node_label=node_label, end_user_id=end_user_id, - current_time=current_time + current_time=current_time, + access_times=access_times ) - tasks.append(task) + tasks.append((node_id, task)) - # Execute all tasks in parallel - task_results = await asyncio.gather(*tasks, return_exceptions=True) + task_results = await asyncio.gather( + *[t for _, t in tasks], return_exceptions=True + ) - # Collect successful results and count failures results = [] failed_count = 0 - for node_id, result in zip(node_ids, task_results): + for (node_id, _), result in zip(tasks, task_results): if isinstance(result, Exception): failed_count += 1 logger.warning( @@ -225,12 +217,12 @@ class AccessHistoryManager: batch_duration = time.time() - batch_start logger.info( - f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, " + f"[PERF] 批量访问记录完成: 成功 
{len(results)}/{len(access_count_map)}, " f"失败 {failed_count}, 耗时 {batch_duration:.4f}s" ) return results - + async def check_consistency( self, node_id: str, @@ -239,22 +231,6 @@ class AccessHistoryManager: ) -> Tuple[ConsistencyCheckResult, Optional[str]]: """ 检查节点数据的一致性 - - 验证以下一致性规则: - 1. access_history[-1] == last_access_time - 2. len(access_history) == access_count - 3. 如果有访问历史,必须有激活值 - 4. 激活值必须在有效范围内 [offset, 1.0] - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - Tuple[ConsistencyCheckResult, Optional[str]]: - - 一致性检查结果枚举 - - 错误描述(如果不一致) """ node_data = await self._fetch_node(node_id, node_label, end_user_id) @@ -266,7 +242,6 @@ class AccessHistoryManager: access_count = node_data.get('access_count', 0) activation_value = node_data.get('activation_value') - # 检查1:access_history[-1] == last_access_time if access_history and last_access_time: if access_history[-1] != last_access_time: return ( @@ -275,7 +250,6 @@ class AccessHistoryManager: f"last_access_time={last_access_time}" ) - # 检查2:len(access_history) == access_count if len(access_history) != access_count: return ( ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT, @@ -283,14 +257,12 @@ class AccessHistoryManager: f"access_count={access_count}" ) - # 检查3:有访问历史必须有激活值 if access_history and activation_value is None: return ( ConsistencyCheckResult.MISSING_ACTIVATION, "Node has access_history but activation_value is None" ) - # 检查4:激活值范围 if activation_value is not None: offset = self.actr_calculator.offset if not (offset <= activation_value <= 1.0): @@ -301,30 +273,14 @@ class AccessHistoryManager: ) return ConsistencyCheckResult.CONSISTENT, None - + async def check_batch_consistency( self, node_label: str, end_user_id: Optional[str] = None, limit: int = 1000 ) -> Dict[str, Any]: - """ - 批量检查多个节点的一致性 - - Args: - node_label: 节点标签 - end_user_id: 组ID(可选) - limit: 检查的最大节点数 - - Returns: - Dict[str, Any]: 一致性检查报告,包含: - - total_checked: 检查的节点总数 - - consistent_count: 一致的节点数 - - 
inconsistent_count: 不一致的节点数 - - inconsistencies: 不一致节点的详细信息列表 - - consistency_rate: 一致性率(0-1) - """ - # 查询所有相关节点 + """批量检查多个节点的一致性""" query = f""" MATCH (n:{node_label}) WHERE n.access_history IS NOT NULL @@ -343,7 +299,6 @@ class AccessHistoryManager: results = await self.connector.execute_query(query, **params) node_ids = [r['id'] for r in results] - # 检查每个节点 inconsistencies = [] consistent_count = 0 @@ -382,32 +337,15 @@ class AccessHistoryManager: ) return report - + async def repair_inconsistency( self, node_id: str, node_label: str, end_user_id: Optional[str] = None ) -> bool: - """ - 自动修复节点的数据不一致问题 - - 修复策略: - 1. 如果access_history[-1] != last_access_time:使用access_history[-1] - 2. 如果len(access_history) != access_count:使用len(access_history) - 3. 如果有历史但无激活值:重新计算激活值 - 4. 如果激活值超出范围:重新计算激活值 - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - bool: 修复成功返回True,否则返回False - """ + """自动修复节点的数据不一致问题""" try: - # 检查一致性 result, message = await self.check_consistency( node_id=node_id, node_label=node_label, @@ -418,7 +356,6 @@ class AccessHistoryManager: logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]") return True - # 获取节点数据 node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") @@ -427,17 +364,13 @@ class AccessHistoryManager: access_history = node_data.get('access_history') or [] importance_score = node_data.get('importance_score', 0.5) - # 准备修复数据 repair_data = {} - # 修复last_access_time if access_history: repair_data['last_access_time'] = access_history[-1] - # 修复access_count repair_data['access_count'] = len(access_history) - # 修复activation_value if access_history: current_time = datetime.now() last_access_dt = datetime.fromisoformat(access_history[-1]) @@ -453,7 +386,6 @@ class AccessHistoryManager: ) repair_data['activation_value'] = activation_value - # 执行修复 query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ @@ -484,26 +416,16 @@ class 
AccessHistoryManager: f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}" ) return False - + # ==================== 私有辅助方法 ==================== - + async def _fetch_node( self, node_id: str, node_label: str, end_user_id: Optional[str] = None ) -> Optional[Dict[str, Any]]: - """ - 获取节点数据 - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - Optional[Dict[str, Any]]: 节点数据,如果不存在返回None - """ + """获取节点数据""" query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ @@ -527,12 +449,13 @@ class AccessHistoryManager: if results: return results[0] return None - + async def _calculate_update( self, node_data: Dict[str, Any], current_time: datetime, - current_time_iso: str + current_time_iso: str, + access_times: int = 1 ) -> Dict[str, Any]: """ 计算更新数据 @@ -541,45 +464,40 @@ class AccessHistoryManager: node_data: 当前节点数据 current_time: 当前时间(datetime对象) current_time_iso: 当前时间(ISO格式字符串) + access_times: 本次访问次数(合并后可能大于1) Returns: - Dict[str, Any]: 更新数据,包含所有需要更新的字段 + Dict[str, Any]: 更新数据 """ - access_history = node_data.get('access_history') or [] - # Handle None importance_score - default to 0.5 importance_score = node_data.get('importance_score') if importance_score is None: importance_score = 0.5 - # 追加新的访问时间 - new_access_history = access_history + [current_time_iso] + # 本次新增的时间戳 + new_timestamps = [current_time_iso] * access_times - # 修剪访问历史(如果过长) - access_history_dt = [ - datetime.fromisoformat(ts) for ts in new_access_history - ] + # 仅用本次新增的访问记录计算激活值 + new_history_dt = [current_time] * access_times trimmed_history_dt = self.actr_calculator.trim_access_history( - access_history=access_history_dt, + access_history=new_history_dt, current_time=current_time ) - trimmed_history = [ts.isoformat() for ts in trimmed_history_dt] - # 计算新的激活值 activation_value = self.actr_calculator.calculate_memory_activation( access_history=trimmed_history_dt, current_time=current_time, - last_access_time=current_time, # 最后访问时间就是当前时间 + last_access_time=current_time, 
importance_score=importance_score ) - # 返回所有需要更新的字段 return { 'activation_value': activation_value, - 'access_history': trimmed_history, + 'new_timestamps': new_timestamps, + 'access_count_delta': access_times, + 'access_count': len(trimmed_history_dt), 'last_access_time': current_time_iso, - 'access_count': len(trimmed_history) } - + async def _atomic_update( self, node_id: str, @@ -588,10 +506,10 @@ class AccessHistoryManager: end_user_id: Optional[str] = None ) -> Dict[str, Any]: """ - 原子性更新节点(使用乐观锁) + 原子性更新节点(使用 APOC 原子操作) - 使用Neo4j事务和版本号确保所有字段同时更新或回滚。 - 实现乐观锁机制防止并发冲突。 + 使用 apoc.atomic.add 和 apoc.atomic.insert 保证并发安全, + 无需 version 字段和乐观锁,数据库层面保证原子性。 Args: node_id: 节点ID @@ -603,126 +521,68 @@ class AccessHistoryManager: Dict[str, Any]: 更新后的节点数据 Raises: - RuntimeError: 如果更新失败或发生版本冲突 + RuntimeError: 如果更新失败 """ - # 定义事务函数 - async def update_transaction(tx, node_id, node_label, update_data, end_user_id): - # 步骤1:读取当前节点并获取版本号 - read_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - """ - if end_user_id: - read_query += " WHERE n.end_user_id = $end_user_id" - read_query += """ - RETURN n.id as id, - n.version as version, - n.activation_value as activation_value, - n.access_history as access_history, - n.last_access_time as last_access_time, - n.access_count as access_count, - n.importance_score as importance_score - """ + content_field_map = { + 'Statement': 'n.statement as statement', + 'MemorySummary': 'n.content as content', + 'ExtractedEntity': 'null as content_placeholder', + 'Community': 'n.summary as summary' + } + + if node_label not in content_field_map: + raise ValueError( + f"Unsupported node_label: {node_label}. 
" + f"Supported labels are: {list(content_field_map.keys())}" + ) + + content_field = content_field_map[node_label] + + where_clause = "" + if end_user_id: + where_clause = " AND n.end_user_id = $end_user_id" + + query = f""" + MATCH (n:{node_label} {{id: $node_id}}) + WHERE true{where_clause} + CALL apoc.atomic.add(n, 'access_count', $access_count_delta, 5) YIELD oldValue AS old_count + WITH n + CALL (n) {{ + UNWIND $new_timestamps AS ts + CALL apoc.atomic.insert(n, 'access_history', size(n.access_history), ts, 5) YIELD oldValue + RETURN count(*) AS inserted + }} + SET n.activation_value = $activation_value, + n.last_access_time = $last_access_time + RETURN n.id as id, + n.activation_value as activation_value, + n.access_history as access_history, + n.last_access_time as last_access_time, + n.access_count as access_count, + n.importance_score as importance_score, + {content_field} + """ + + params = { + 'node_id': node_id, + 'access_count_delta': update_data['access_count_delta'], + 'new_timestamps': update_data['new_timestamps'], + 'activation_value': update_data['activation_value'], + 'last_access_time': update_data['last_access_time'], + } + if end_user_id: + params['end_user_id'] = end_user_id + + try: + results = await self.connector.execute_query(query, **params) - read_params = {'node_id': node_id} - if end_user_id: - read_params['end_user_id'] = end_user_id - - read_result = await tx.run(read_query, **read_params) - current_node = await read_result.single() - - if not current_node: + if not results: raise RuntimeError(f"Node not found: {node_label}[{node_id}]") - # 获取当前版本号(如果不存在则为0) - current_version = current_node.get('version', 0) or 0 - new_version = current_version + 1 - - # 步骤2:使用乐观锁更新节点 - # 根据节点类型构建完整的查询语句 - content_field_map = { - 'Statement': 'n.statement as statement', - 'MemorySummary': 'n.content as content', - 'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤 - } - - # 显式检查节点类型,不支持的类型抛出错误 - if node_label not in content_field_map: - 
raise ValueError( - f"Unsupported node_label: {node_label}. " - f"Supported labels are: {list(content_field_map.keys())}" - ) - - content_field = content_field_map[node_label] - - # 构建 WHERE 子句 - where_conditions = [] - if end_user_id: - where_conditions.append("n.end_user_id = $end_user_id") - - # 添加版本检查 - if current_version > 0: - where_conditions.append("n.version = $current_version") - else: - where_conditions.append("(n.version IS NULL OR n.version = 0)") - - where_clause = " AND ".join(where_conditions) if where_conditions else "true" - - # 构建完整的更新查询 - update_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - WHERE {where_clause} - SET n.activation_value = $activation_value, - n.access_history = $access_history, - n.last_access_time = $last_access_time, - n.access_count = $access_count, - n.version = $new_version - RETURN n.id as id, - n.activation_value as activation_value, - n.access_history as access_history, - n.last_access_time as last_access_time, - n.access_count as access_count, - n.importance_score as importance_score, - n.version as version, - {content_field} - """ - - update_params = { - 'node_id': node_id, - 'current_version': current_version, - 'new_version': new_version, - 'activation_value': update_data['activation_value'], - 'access_history': update_data['access_history'], - 'last_access_time': update_data['last_access_time'], - 'access_count': update_data['access_count'] - } - if end_user_id: - update_params['end_user_id'] = end_user_id - - update_result = await tx.run(update_query, **update_params) - updated_node = await update_result.single() - - if not updated_node: - raise RuntimeError( - f"Version conflict detected for {node_label}[{node_id}]. " - f"Expected version {current_version}, but node was modified by another transaction." 
- ) - - # 转换为字典并移除占位符字段 - result_dict = dict(updated_node) + result_dict = dict(results[0]) result_dict.pop('content_placeholder', None) return result_dict - - # 执行事务 - try: - result = await self.connector.execute_write_transaction( - update_transaction, - node_id=node_id, - node_label=node_label, - update_data=update_data, - end_user_id=end_user_id - ) - return result except Exception as e: logger.error( f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}" diff --git a/api/app/core/memory/storage_services/search/keyword_search.py b/api/app/core/memory/storage_services/search/keyword_search.py index d2591945..2458cf30 100644 --- a/api/app/core/memory/storage_services/search/keyword_search.py +++ b/api/app/core/memory/storage_services/search/keyword_search.py @@ -5,7 +5,7 @@ 使用Neo4j的全文索引进行高效的文本匹配。 """ -from typing import List, Dict, Any, Optional +from typing import List, Optional from app.core.logging_config import get_memory_logger from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult @@ -74,7 +74,7 @@ class KeywordSearchStrategy(SearchStrategy): # 调用底层的关键词搜索函数 results_dict = await search_graph( connector=self.connector, - q=query_text, + query=query_text, end_user_id=end_user_id, limit=limit, include=include_list diff --git a/api/app/core/memory/utils/data/text_utils.py b/api/app/core/memory/utils/data/text_utils.py index d0b10f97..eaed0940 100644 --- a/api/app/core/memory/utils/data/text_utils.py +++ b/api/app/core/memory/utils/data/text_utils.py @@ -22,7 +22,9 @@ def escape_lucene_query(query: str) -> str: s = s.replace("\r", " ").replace("\n", " ").strip() # Lucene reserved tokens/special characters - specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':'] + # NOTE: '/' is the regex delimiter in Lucene — must be escaped to prevent + # TokenMgrError when the query contains unmatched slashes. 
+ specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/'] # Replace longer tokens first to avoid partial double-escaping for token in sorted(specials, key=len, reverse=True): s = s.replace(token, f"\\{token}") diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 0cea98f2..a1ad885e 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -1,6 +1,6 @@ import os from jinja2 import Environment, FileSystemLoader - +from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering # Setup Jinja2 environment @@ -205,6 +205,7 @@ async def render_triplet_extraction_prompt( predicate_instructions: dict = None, language: str = "zh", ontology_types: "OntologyTypeList | None" = None, + speaker: str = None, ) -> str: """ Renders the triplet extraction prompt using the extract_triplet.jinja2 template. 
@@ -216,6 +217,7 @@ async def render_triplet_extraction_prompt( predicate_instructions: Optional predicate instructions language: The language to use for entity descriptions ("zh" for Chinese, "en" for English) ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification + speaker: Speaker role ("user" or "assistant") for the current statement Returns: Rendered prompt content as string @@ -223,7 +225,7 @@ async def render_triplet_extraction_prompt( template = prompt_env.get_template("extract_triplet.jinja2") # 准备本体类型数据 - ontology_type_section = "" + ontology_type_section = None ontology_type_names = [] type_hierarchy_hints = [] if ontology_types and ontology_types.types: @@ -240,6 +242,7 @@ async def render_triplet_extraction_prompt( ontology_types=ontology_type_section, ontology_type_names=ontology_type_names, type_hierarchy_hints=type_hierarchy_hints, + speaker=speaker, ) # 记录渲染结果到提示日志(与示例日志结构一致) log_prompt_rendering('triplet extraction', rendered_prompt) diff --git a/api/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 index 3cdb5fd0..611bd6df 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 @@ -43,8 +43,9 @@ Each statement must be labeled as per the criteria mentioned below. 对话上下文和共指消解: - 将每个陈述句归属于说出它的参与者。 -- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户")。 -- 将所有代词解析为对话上下文中的具体人物或实体。 +- **对于用户的发言:必须使用"用户"作为主语**,禁止将"用户"或"我"替换为用户的真实姓名或别名。例如,用户说"我叫张三"应提取为"用户叫张三",而不是"张三叫张三"。 +- 对于 AI 助手的发言:使用"助手"或"AI助手"作为主语。 +- 将所有代词解析为对话上下文中的具体人物或实体,但"我"必须解析为"用户"。 - 识别并将抽象引用解析为其具体名称(如果提到)。 - 将缩写和首字母缩略词扩展为其完整形式。 {% else %} @@ -68,8 +69,9 @@ Context Resolution Requirements: Conversational Context & Co-reference Resolution: - Attribute every statement to the participant who uttered it. 
-- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户"). -- Resolve all pronouns to the specific person or entity from the conversation's context. +- **For user's statements: always use "用户" (User) as the subject**. Do NOT replace "用户" or "I" with the user's real name or alias. For example, if the user says "I'm John", extract as "用户 is John", not "John is John". +- For AI assistant's statements: use "助手" or "AI助手" as the subject. +- Resolve all pronouns to the specific person or entity from the conversation's context, but "I"/"我" must always resolve to "用户". - Identify and resolve abstract references to their specific names if mentioned. - Expand abbreviations and acronyms to their full form. {% endif %} @@ -139,13 +141,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合 示例输出: { "statements": [ { - "statement": "Sarah Chen 最近一直在尝试水彩画。", + "statement": "用户最近一直在尝试水彩画。", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" }, { - "statement": "Sarah Chen 画了一些花朵。", + "statement": "用户画了一些花朵。", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" @@ -157,13 +159,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合 "relevance": "IRRELEVANT" }, { - "statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。", + "statement": "用户认为她的水彩画中的色彩组合可以改进。", "statement_type": "OPINION", "temporal_type": "STATIC", "relevance": "RELEVANT" }, { - "statement": "Sarah Chen 真的很喜欢玫瑰和百合。", + "statement": "用户真的很喜欢玫瑰和百合。", "statement_type": "FACT", "temporal_type": "STATIC", "relevance": "RELEVANT" @@ -186,13 +188,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合 示例输出: { "statements": [ { - "statement": "张曼婷最近在尝试水彩画。", + "statement": "用户最近在尝试水彩画。", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" }, { - "statement": "张曼婷画了一些花朵。", + "statement": "用户画了一些花朵。", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" @@ -204,13 +206,13 @@ AI: 
"水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合 "relevance": "IRRELEVANT" }, { - "statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。", + "statement": "用户觉得水彩画的色彩搭配还有提升的空间。", "statement_type": "OPINION", "temporal_type": "STATIC", "relevance": "RELEVANT" }, { - "statement": "张曼婷很喜欢玫瑰和百合。", + "statement": "用户很喜欢玫瑰和百合。", "statement_type": "FACT", "temporal_type": "STATIC", "relevance": "RELEVANT" @@ -233,13 +235,13 @@ User: "I think the color combinations could use some improvement, but I really l Example Output: { "statements": [ { - "statement": "Sarah Chen has been trying watercolor painting recently.", + "statement": "用户 has been trying watercolor painting recently.", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" }, { - "statement": "Sarah Chen painted some flowers.", + "statement": "用户 painted some flowers.", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" @@ -251,13 +253,13 @@ Example Output: { "relevance": "IRRELEVANT" }, { - "statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.", + "statement": "用户 thinks the color combinations in her watercolor paintings could use some improvement.", "statement_type": "OPINION", "temporal_type": "STATIC", "relevance": "RELEVANT" }, { - "statement": "Sarah Chen really likes roses and lilies.", + "statement": "用户 really likes roses and lilies.", "statement_type": "FACT", "temporal_type": "STATIC", "relevance": "RELEVANT" @@ -280,13 +282,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合 Example Output: { "statements": [ { - "statement": "张曼婷最近在尝试水彩画。", + "statement": "用户最近在尝试水彩画。", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" }, { - "statement": "张曼婷画了一些花朵。", + "statement": "用户画了一些花朵。", "statement_type": "FACT", "temporal_type": "DYNAMIC", "relevance": "RELEVANT" @@ -298,13 +300,13 @@ Example Output: { "relevance": "IRRELEVANT" }, { - "statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。", + "statement": "用户觉得水彩画的色彩搭配还有提升的空间。", "statement_type": 
"OPINION", "temporal_type": "STATIC", "relevance": "RELEVANT" }, { - "statement": "张曼婷很喜欢玫瑰和百合。", + "statement": "用户很喜欢玫瑰和百合。", "statement_type": "FACT", "temporal_type": "STATIC", "relevance": "RELEVANT" diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index 6605532d..1a79b482 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -23,6 +23,16 @@ Extract entities and knowledge triplets from the given statement. ===Inputs=== **Chunk Content:** "{{ chunk_content }}" **Statement:** "{{ statement }}" +{% if speaker %} +**Speaker:** {{ speaker }} +{% if speaker == "assistant" %} +{% if language == "zh" %} +⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。 +{% else %} +⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants. +{% endif %} +{% endif %} +{% endif %} {% if ontology_types %} ===Ontology Type Guidance=== @@ -87,7 +97,17 @@ Extract entities and knowledge triplets from the given statement. 
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name) * "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name) - 空值:如果没有别名,使用 `[]` - - 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字 + - **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。** + - **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。** + - **🚨 说话人视角:当 speaker 为 assistant 时,AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases,绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。** + * "我叫陈思远,我给AI取名为远仔" → 用户 aliases=["陈思远"],AI助手 aliases=["远仔"] + * "我叫vv" → 用户 aliases=["vv"](没有给AI取名的表达,名字归用户) + * [speaker=assistant] "好的,VV" → 用户 aliases=["VV"](AI 在称呼用户,原文中出现了"VV") + * [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"](AI 在自我介绍,这是 AI 的别名) + * ❌ 错误:将"远仔"放入用户的 aliases("远仔"是给AI取的名字,不是用户的名字) + * ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases + * ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases + * ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉) {% else %} - Include: nicknames, full names, abbreviations, alternative names - Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST** @@ -96,7 +116,17 @@ Extract entities and knowledge triplets from the given statement. * "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name) * "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name) - Empty: If no aliases, use `[]` - - Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names + - **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. 
Smith" unless those exact strings appear in the conversation.** + - **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.** + - **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.** + * "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"] + * "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user) + * [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text) + * [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias) + * ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user) + * ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases + * ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases + * ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants) {% endif %} @@ -122,7 +152,60 @@ Extract entities and knowledge triplets from the given statement. -4. **ALIASES ORDER:** +4. 
**AI/ASSISTANT ENTITY SPECIAL HANDLING:** +{% if language == "zh" %} + - **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。** + - 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases + - AI/助手实体的 name 字段:使用 "AI助手" + - 用户给 AI 取的名字:放入 AI/助手实体的 aliases + - **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中** + - **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式** + - **🚨 "你不叫X了"/"你不叫X,你叫Y" 句式:X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"(AI)。** + - **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式** + - **🚨 speaker=assistant 时的特殊规则:** + * AI 用来称呼用户的名字 → 归入**用户**实体的 aliases(但必须是原文中逐字出现的称呼,不能推测) + * AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases + * 判断依据:AI 说"你叫X"或用 X 称呼用户 → X 是用户别名;AI 说"我叫X"或"我是X" → X 是 AI 别名 + - 示例: + * "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户) + * "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户) + * "我叫陈思远,我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["远仔"] + * "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"];AI实体: name="AI助手", aliases=["小助"] + * "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]("远仔"是AI旧名,"陈仔"是AI新名,都归AI。不要把"远仔"或"陈仔"放入用户的aliases) + * [speaker=assistant] "好的VV,今天想干点啥?" → 用户实体: name="用户", aliases=["VV"](AI 在称呼用户,原文中出现了"VV") + * [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["陈仔"] + * ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases(没有任何给 AI 取名的表达) + * ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases + * ❌ 错误:aliases=["陈思远", "远仔"]("远仔"是给AI取的名字,不是用户的名字) + * ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉) +{% else %} + - **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. 
Do NOT guess or infer that a name is for the AI.** + - Only create an AI/assistant entity when the user **explicitly** names the AI/assistant + - AI/assistant entity name field: use "AI Assistant" + - Names the user gives to the AI: put in the AI/assistant entity's aliases + - **🚨 NEVER put names given to the AI into the user entity's aliases** + - **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns** + - **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).** + - **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns** + - **🚨 Special rules when speaker=assistant:** + * Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer) + * Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases + * Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias + - Examples: + * "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user) + * "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user) + * "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"] + * "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"] + * "You're not 
called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names) + * [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text) + * [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"] + * ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists) + * ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases + * ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user) + * ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants) +{% endif %} + +5. **ALIASES ORDER:** {% if language == "zh" %} - 顺序优先级:按出现顺序,先出现的在前 {% else %} @@ -202,8 +285,19 @@ Output: {"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false} ] } + +**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy" +Output: +{ + "triplets": [ + {"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"} + ], + "entities": [ + {"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false}, + {"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false} + ] +} {% else %} -**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre." 
Output: { "triplets": [ @@ -258,6 +352,39 @@ Output: ] } +**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远,我给AI取名为远仔" +Output: +{ + "triplets": [ + {"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"} + ], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false}, + {"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false} + ] +} + +**Example 7 (纯用户自我介绍,无AI命名 - Chinese):** "我叫vv" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false} + ] +} + +**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔" +Output: +{ + "triplets": [ + {"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"} + ], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false}, + {"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false} + ] +} + {% endif %} ===End of Examples=== @@ -279,4 +406,12 @@ Output: - **⚠️ ALIASES ORDER: preserve temporal order of appearance** - **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []** +**Output JSON structure:** +```json +{ + "triplets": [...], + "entities": [...] 
+} +``` + {{ json_schema }} diff --git a/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 new file mode 100644 index 00000000..5d019b12 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/extract_user_metadata.jinja2 @@ -0,0 +1,135 @@ +===Task=== +Extract user metadata from the following conversation statements spoken by the user. + +{% if language == "zh" %} +**"三度原则"判断标准:** +- 复用度:该信息是否会被多个功能模块使用? +- 约束度:该信息是否会影响系统行为? +- 时效性:该信息是长期稳定的还是临时的?仅提取长期稳定信息。 + +**提取规则:** +- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息 +- 仅提取文本中明确提到的信息,不要推测 +- 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象 +- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值) + +{% if existing_metadata %} +**重要:合并已有元数据** +下方提供了数据库中已有的用户元数据。请结合用户最新发言,输出**合并后的完整元数据**: +- 如果用户明确否定了已有信息(如"我不再教高中物理了"),在输出中**移除**该信息 +- 如果用户提到了新信息,**添加**到对应字段中 +- 如果已有信息未被用户否定,**保留**在输出中 +- 标量字段(如 role、domain):如果用户提到了新值,用新值替换;否则保留已有值 +- 最终输出应该是完整的、合并后的元数据,不是增量 +{% endif %} + +**字段说明:** +- profile.role:用户的职业或角色,如 教师、医生、后端工程师 +- profile.domain:用户所在领域,如 教育、医疗、软件开发 +- profile.expertise:用户擅长的技能或工具(通用,不限于编程),如 Python、心理咨询、高中物理 +- profile.interests:用户主动表达兴趣的话题或领域标签 +- behavioral_hints.learning_stage:学习阶段(初学者/中级/高级) +- behavioral_hints.preferred_depth:偏好深度(概览/技术细节/深入探讨) +- behavioral_hints.tone_preference:语气偏好(轻松随意/专业简洁/学术严谨) +- knowledge_tags:用户涉及的知识领域标签 + +**用户别名变更(增量模式):** +- **aliases_to_add**:本次新发现的用户别名,包括: + * 用户主动自我介绍:如"我叫张三"、"我的名字是XX"、"我的网名是XX" + * 他人对用户的称呼:如"同事叫我陈哥"、"大家叫我小张"、"领导叫我老陈" + * 只提取原文中逐字出现的名字,严禁推测或创造 + * 禁止提取:用户给 AI 取的名字、第三方人物自身的名字、"用户"/"我" 等占位词 + * 如果没有新别名,返回空数组 `[]` +- **aliases_to_remove**:用户明确否认的别名,包括: + * 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组 + * **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名 + * 例如:用户说"我不叫陈小刀了" → 只移除"陈小刀",不要移除"陈哥"、"老陈"等未被提及的别名 + * 如果没有要移除的别名,返回空数组 `[]` +{% if existing_aliases %} +- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复) +{% endif %} +{% else %} +**"Three-Degree Principle" criteria:** +- 
Reusability: Will this information be used by multiple functional modules? +- Constraint: Will this information affect system behavior? +- Timeliness: Is this information long-term stable or temporary? Only extract long-term stable information. + +**Extraction rules:** +- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user +- Only extract information explicitly mentioned in the text, do not speculate +- If no user profile information can be extracted, return an empty user_metadata object +- **Output language must match the input text language** + +{% if existing_metadata %} +**Important: Merge with existing metadata** +Existing user metadata from the database is provided below. Combine with the user's latest statements to output the **complete merged metadata**: +- If the user explicitly negates existing info (e.g. "I no longer teach high school physics"), **remove** it from output +- If the user mentions new info, **add** it to the corresponding field +- If existing info is not negated by the user, **keep** it in the output +- Scalar fields (e.g. role, domain): replace with new value if user mentions one; otherwise keep existing +- The final output should be the complete, merged metadata — not an incremental update +{% endif %} + +**Field descriptions:** +- profile.role: User's occupation or role, e.g. teacher, doctor, software engineer +- profile.domain: User's domain, e.g. 
education, healthcare, software development +- profile.expertise: User's skills or tools (general, not limited to programming) +- profile.interests: Topics or domain tags the user actively expressed interest in +- behavioral_hints.learning_stage: Learning stage (beginner/intermediate/advanced) +- behavioral_hints.preferred_depth: Preferred depth (overview/detailed/deep dive) +- behavioral_hints.tone_preference: Tone preference (casual/professional/academic) +- knowledge_tags: Knowledge domain tags related to the user + +**User alias changes (incremental mode):** +- **aliases_to_add**: Newly discovered user aliases from this conversation, including: + * User self-introductions: e.g. "I'm John", "My name is XX", "My username is XX" + * How others address the user: e.g. "My colleagues call me Johnny", "People call me Mike" + * Only extract names that appear VERBATIM in the text — never infer or fabricate + * Do NOT extract: names the user gives to the AI, third-party people's own names, placeholder words like "User"/"I" + * If no new aliases, return empty array `[]` +- **aliases_to_remove**: Aliases the user explicitly denies, including: + * User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array + * **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. 
Do NOT infer or remove related aliases + * Example: User says "I'm not called John anymore" → only remove "John", do NOT remove "Johnny", "J" or other related aliases not mentioned + * If no aliases to remove, return empty array `[]` +{% if existing_aliases %} +- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output) +{% endif %} +{% endif %} + +===User Statements=== +{% for stmt in statements %} +- {{ stmt }} +{% endfor %} + +{% if existing_metadata %} +===Existing User Metadata=== +```json +{{ existing_metadata | tojson }} +``` +{% endif %} + +===Output Format=== +Return a JSON object with the following structure: +```json +{ + "user_metadata": { + "profile": { + "role": "", + "domain": "", + "expertise": [], + "interests": [] + }, + "behavioral_hints": { + "learning_stage": "", + "preferred_depth": "", + "tone_preference": "" + }, + "knowledge_tags": [] + }, + "aliases_to_add": [], + "aliases_to_remove": [] +} +``` + +{{ json_schema }} diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 80117f27..7b570b47 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import Any, Dict, Optional, TypeVar +from typing import Any, Dict, List, Optional, TypeVar from langchain_aws import ChatBedrock from langchain_community.chat_models import ChatTongyi @@ -9,11 +9,12 @@ from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLLM from langchain_ollama import OllamaLLM from langchain_openai import ChatOpenAI, OpenAI -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.models.models_model import ModelProvider, ModelType +from app.core.models.compatible_chat import CompatibleChatOpenAI T = TypeVar("T") @@ -24,7 +25,11 @@ class 
RedBearModelConfig(BaseModel): provider: str api_key: str base_url: Optional[str] = None + capability: List[str] = Field(default_factory=list) # 模型能力列表,驱动所有能力开关 is_omni: bool = False # 是否为 Omni 模型 + deep_thinking: bool = False # 是否启用深度思考模式 + thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算 + json_output: bool = False # 是否强制 JSON 输出 # 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置 timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 @@ -32,6 +37,23 @@ class RedBearModelConfig(BaseModel): concurrency: int = 5 # 并发限流 extra_params: Dict[str, Any] = {} + @model_validator(mode="after") + def _resolve_capabilities(self) -> "RedBearModelConfig": + from app.core.logging_config import get_business_logger + logger = get_business_logger() + if self.deep_thinking and "thinking" not in self.capability: + logger.warning( + f"模型 {self.model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking" + ) + self.deep_thinking = False + self.thinking_budget_tokens = None + if self.json_output and "json_output" not in self.capability: + logger.warning( + f"模型 {self.model_name} 不支持 JSON 输出(capability 中无 'json_output'),已自动关闭 json_output" + ) + self.json_output = False + return self + class RedBearModelFactory: """模型工厂类""" @@ -44,7 +66,7 @@ class RedBearModelFactory: # 打印供应商信息用于调试 from app.core.logging_config import get_business_logger logger = get_business_logger() - logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}") + logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}, deep_thinking: {config.deep_thinking}") # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: @@ -58,7 +80,7 @@ class RedBearModelFactory: write=60.0, pool=10.0, ) - return { + params: Dict[str, Any] = { "model": config.model_name, "base_url": config.base_url, "api_key": 
config.api_key, @@ -66,6 +88,26 @@ class RedBearModelFactory: "max_retries": config.max_retries, **config.extra_params } + # 流式模式下启用 stream_usage 以获取 token 统计 + is_streaming = bool(config.extra_params.get("streaming")) + if is_streaming: + params["stream_usage"] = True + # 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考 + if "thinking" in config.capability: + extra_body = params.setdefault("extra_body", {}) + if config.deep_thinking: + extra_body["enable_thinking"] = False + if is_streaming: + extra_body["enable_thinking"] = True + if config.thinking_budget_tokens: + extra_body["thinking_budget"] = config.thinking_budget_tokens + params["extra_body"] = extra_body + # JSON 输出模式 + if config.json_output: + model_kwargs = params.setdefault("model_kwargs", {}) + model_kwargs["response_format"] = {"type": "json_object"} + params["model_kwargs"] = model_kwargs + return params if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: # 使用 httpx.Timeout 对象来设置详细的超时配置 @@ -78,7 +120,7 @@ class RedBearModelFactory: write=60.0, # 写入超时:60秒 pool=10.0, # 连接池超时:10秒 ) - return { + params: Dict[str, Any] = { "model": config.model_name, "base_url": config.base_url, "api_key": config.api_key, @@ -86,16 +128,56 @@ class RedBearModelFactory: "max_retries": config.max_retries, **config.extra_params } + # 流式模式下启用 stream_usage 以获取 token 统计 + is_streaming = bool(config.extra_params.get("streaming")) + if is_streaming: + params["stream_usage"] = True + # 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考 + if "thinking" in config.capability: + # VOLCANO 深度思考仅流式支持 + if provider == ModelProvider.VOLCANO: + thinking_config: Dict[str, Any] = {"type": "enabled" if config.deep_thinking else "disabled"} + if config.deep_thinking and config.thinking_budget_tokens: + thinking_config["budget_tokens"] = config.thinking_budget_tokens + params["extra_body"] = {"thinking": thinking_config} + else: + extra_body = 
params.setdefault("extra_body", {}) + if config.deep_thinking: + extra_body["enable_thinking"] = False + if is_streaming: + extra_body["enable_thinking"] = True + if config.thinking_budget_tokens: + extra_body["thinking_budget"] = config.thinking_budget_tokens + params["extra_body"] = extra_body + # JSON 输出模式 + if config.json_output: + params.setdefault("model_kwargs", {}) + params["model_kwargs"]["response_format"] = {"type": "json_object"} + return params elif provider == ModelProvider.DASHSCOPE: - # DashScope (通义千问) 使用自己的参数格式 - # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 - # 只支持: model, dashscope_api_key, max_retries, client - return { + params = { "model": config.model_name, "dashscope_api_key": config.api_key, "max_retries": config.max_retries, **config.extra_params } + # 支持 thinking 的模型始终传 enable_thinking,关闭时显式传 False 避免模型默认开启思考 + if "thinking" in config.capability: + is_streaming = bool(config.extra_params.get("streaming")) + model_kwargs = params.setdefault("model_kwargs", {}) + if config.deep_thinking: + model_kwargs["enable_thinking"] = False + if is_streaming: + model_kwargs["enable_thinking"] = True + model_kwargs["incremental_output"] = True + if config.thinking_budget_tokens: + model_kwargs["thinking_budget"] = config.thinking_budget_tokens + params["model_kwargs"] = model_kwargs + if config.json_output: + model_kwargs = params.setdefault("model_kwargs", {}) + model_kwargs["response_format"] = {"type": "json_object"} + params["model_kwargs"] = model_kwargs + return params elif provider == ModelProvider.BEDROCK: # Bedrock 使用 AWS 凭证 # api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id @@ -134,6 +216,17 @@ class RedBearModelFactory: elif "region_name" not in params: params["region_name"] = "us-east-1" # 默认区域 + # 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型 + # 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项) + if config.deep_thinking: + budget = config.thinking_budget_tokens or 10000 + 
params["additional_model_request_fields"] = { + "thinking": {"type": "enabled", "budget_tokens": budget} + } + # JSON 输出模式 + if config.json_output: + params.setdefault("model_kwargs", {}) + params["model_kwargs"]["response_format"] = {"type": "json_object"} return params else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) @@ -145,10 +238,15 @@ class RedBearModelFactory: if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: return { "model": config.model_name, - # "base_url": config.base_url, "jina_api_key": config.api_key, **config.extra_params } + elif provider == ModelProvider.DASHSCOPE: + return { + "model": config.model_name, + "dashscope_api_key": config.api_key, + **config.extra_params + } else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) @@ -157,10 +255,12 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy """根据模型提供商获取对应的模型类""" provider = config.provider.lower() - # dashscope 的 omni 模型使用 OpenAI 兼容模式 + # dashscope的omni模型 和 volcano模型使用 if provider == ModelProvider.DASHSCOPE and config.is_omni: - return ChatOpenAI - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]: + return CompatibleChatOpenAI + if provider == ModelProvider.VOLCANO: + return CompatibleChatOpenAI + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: if type == ModelType.LLM: return OpenAI elif type == ModelType.CHAT: @@ -202,6 +302,9 @@ def get_provider_rerank_class(provider: str): if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_community.document_compressors import JinaRerank return JinaRerank + elif provider == ModelProvider.DASHSCOPE: + from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank + return DashScopeRerank # elif provider == ModelProvider.OLLAMA: # from langchain_ollama import OllamaEmbeddings # 
def _lookup(container: Any, key: str) -> Any:
    """Read ``key`` from either a dict or an attribute-style response object.

    Returns ``None`` when ``container`` is ``None`` or does not carry ``key``.
    """
    if isinstance(container, dict):
        return container.get(key)
    return getattr(container, key, None)


class CompatibleChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant that passes deep-thinking text (``reasoning_content``) through.

    Intended for OpenAI-compatible endpoints (Volcano Engine models and
    DashScope/Qwen omni models) that return the model's reasoning in a
    ``reasoning_content`` field next to ``content``. This subclass copies that
    field into ``message.additional_kwargs["reasoning_content"]`` for both
    non-streaming results (``_create_chat_result``) and streaming chunks
    (``_convert_chunk_to_generation_chunk``).
    """

    def _create_chat_result(
        self,
        response: Union[dict, Any],
        generation_info: Optional[dict] = None,
    ) -> ChatResult:
        """Build the ChatResult, then attach ``reasoning_content`` to each generation.

        Args:
            response: Raw completion response, either a dict or an SDK object
                exposing ``choices``.
            generation_info: Forwarded unchanged to the parent implementation.

        Returns:
            The parent's ChatResult; every generation whose matching choice
            carries a non-empty ``reasoning_content`` gets it stored under
            ``additional_kwargs["reasoning_content"]``.
        """
        result = super()._create_chat_result(response, generation_info)

        choices = _lookup(response, "choices") or []
        # Handle every choice, not only the first, so requests with n > 1 keep
        # the reasoning text of each candidate.
        # NOTE(review): assumes the parent emits one generation per choice in the
        # same order - confirm against the pinned langchain_openai version.
        for choice, generation in zip(choices, result.generations):
            reasoning = _lookup(_lookup(choice, "message"), "reasoning_content")
            if reasoning:
                generation.message.additional_kwargs["reasoning_content"] = reasoning
        return result

    def _convert_chunk_to_generation_chunk(
        self,
        chunk: dict,
        default_chunk_class: type,
        base_generation_info: Optional[dict],
    ) -> Optional[ChatGenerationChunk]:
        """Convert one streamed chunk and attach ``delta.reasoning_content`` if present.

        Args:
            chunk: Raw streamed chunk as a dict; ``choices`` is read either at
                the top level or nested under a ``"chunk"`` key.
            default_chunk_class: Forwarded unchanged to the parent implementation.
            base_generation_info: Forwarded unchanged to the parent implementation.

        Returns:
            The parent's generation chunk (or ``None`` when the parent skips the
            chunk), with ``additional_kwargs["reasoning_content"]`` set when the
            first choice's delta contains reasoning text.
        """
        gen_chunk = super()._convert_chunk_to_generation_chunk(
            chunk, default_chunk_class, base_generation_info
        )
        if gen_chunk is None:
            return None

        # `or {}` / `or []` guard against keys that are present but null, which
        # would otherwise raise AttributeError on the chained .get().
        choices = chunk.get("choices") or (chunk.get("chunk") or {}).get("choices") or []
        if choices:
            delta = choices[0].get("delta") or {}
            reasoning: Any = delta.get("reasoning_content")
            if reasoning:
                gen_chunk.message.additional_kwargs["reasoning_content"] = reasoning

        return gen_chunk
+ if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: + import httpx + params = { + "model": config.model_name, + "base_url": config.base_url, + "api_key": config.api_key, + "timeout": httpx.Timeout(timeout=config.timeout, connect=60.0), + "max_retries": config.max_retries + } + elif provider == ModelProvider.DASHSCOPE: + params = { + "model": config.model_name, + "dashscope_api_key": config.api_key, + "max_retries": config.max_retries, + } + elif provider == ModelProvider.OLLAMA: + params = { + "model": config.model_name, + "base_url": config.base_url, + } + elif provider == ModelProvider.BEDROCK: + params = RedBearModelFactory.get_model_params(config) + else: + params = RedBearModelFactory.get_model_params(config) + return embedding_class(**params) def _create_volcano_client(self, config: RedBearModelConfig): """创建火山引擎客户端""" diff --git a/api/app/core/models/rerank.py b/api/app/core/models/rerank.py index c4b91e25..45b6fc88 100644 --- a/api/app/core/models/rerank.py +++ b/api/app/core/models/rerank.py @@ -76,5 +76,9 @@ class RedBearRerank(BaseDocumentCompressor): from langchain_community.document_compressors import JinaRerank model_instance: JinaRerank = self._model return model_instance.rerank(documents=documents, query=query, top_n=top_n) + elif provider == ModelProvider.DASHSCOPE: + from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank + model_instance: DashScopeRerank = self._model + return model_instance.rerank(documents=documents, query=query, top_n=top_n) else: raise ValueError(f"不支持的模型提供商: {provider}") diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index 2c0ab757..f96dba15 100644 --- a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -6,11 +6,13 @@ models: description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口 is_deprecated: false is_official: true - capability: [] + 
capability: + - json_output is_omni: false tags: - 大语言模型 logo: bedrock + - name: amazon nova type: llm provider: bedrock @@ -19,6 +21,7 @@ models: is_official: true capability: - vision + - json_output is_omni: false tags: - 大语言模型 @@ -27,6 +30,7 @@ models: - stream-tool-call - vision logo: bedrock + - name: anthropic claude type: llm provider: bedrock @@ -35,6 +39,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -44,13 +50,15 @@ models: - stream-tool-call - document logo: bedrock + - name: cohere type: llm provider: bedrock description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -58,6 +66,7 @@ models: - tool-call - stream-tool-call logo: bedrock + - name: deepseek type: llm provider: bedrock @@ -66,6 +75,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -74,39 +85,45 @@ models: - tool-call - stream-tool-call logo: bedrock + - name: meta type: llm provider: bedrock description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 - agent-thought - tool-call logo: bedrock + - name: mistral type: llm provider: bedrock description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 - agent-thought - tool-call logo: bedrock + - name: openai type: llm provider: bedrock description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -114,13 +131,15 @@ models: - tool-call - stream-tool-call logo: bedrock + - name: qwen type: llm provider: bedrock description: 
Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -128,6 +147,7 @@ models: - tool-call - stream-tool-call logo: bedrock + - name: amazon.rerank-v1:0 type: rerank provider: bedrock @@ -139,6 +159,7 @@ models: tags: - 重排序模型 logo: bedrock + - name: cohere.rerank-v3-5:0 type: rerank provider: bedrock @@ -150,6 +171,7 @@ models: tags: - 重排序模型 logo: bedrock + - name: amazon.nova-2-multimodal-embeddings-v1:0 type: embedding provider: bedrock @@ -163,6 +185,7 @@ models: - 文本嵌入模型 - vision logo: bedrock + - name: amazon.titan-embed-text-v1 type: embedding provider: bedrock @@ -174,6 +197,7 @@ models: tags: - 文本嵌入模型 logo: bedrock + - name: amazon.titan-embed-text-v2:0 type: embedding provider: bedrock @@ -185,6 +209,7 @@ models: tags: - 文本嵌入模型 logo: bedrock + - name: cohere.embed-english-v3 type: embedding provider: bedrock @@ -196,6 +221,7 @@ models: tags: - 文本嵌入模型 logo: bedrock + - name: cohere.embed-multilingual-v3 type: embedding provider: bedrock diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index 89a16966..9b45f107 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -6,91 +6,109 @@ models: description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 - agent-thought logo: dashscope + - name: deepseek-r1-distill-qwen-32b type: llm provider: dashscope description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 - agent-thought logo: dashscope + - name: deepseek-r1 type: llm provider: dashscope description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式 
is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 - agent-thought logo: dashscope + - name: deepseek-v3.1 type: llm provider: dashscope description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 - agent-thought logo: dashscope + - name: deepseek-v3.2-exp type: llm provider: dashscope description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 - agent-thought logo: dashscope + - name: deepseek-v3.2 type: llm provider: dashscope description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 - agent-thought logo: dashscope + - name: deepseek-v3 type: llm provider: dashscope description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 - agent-thought logo: dashscope + - name: farui-plus type: llm provider: dashscope description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -98,13 +116,15 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: glm-4.7 type: llm provider: dashscope description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -112,6 +132,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qvq-max-latest type: llm provider: dashscope @@ -119,7 +140,9 @@ models: is_deprecated: false is_official: true 
capability: - - vision + - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -127,6 +150,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qvq-max type: llm provider: dashscope @@ -134,7 +158,9 @@ models: is_deprecated: false is_official: true capability: - - vision + - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -142,6 +168,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-coder-turbo-0919 type: llm provider: dashscope @@ -155,13 +182,16 @@ models: - 代码模型 - agent-thought logo: dashscope + - name: qwen-max-latest type: llm provider: dashscope description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -169,6 +199,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-max-longcontext type: llm provider: dashscope @@ -183,13 +214,15 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-max type: llm provider: dashscope description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -197,6 +230,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-mt-plus type: llm provider: dashscope @@ -210,6 +244,7 @@ models: - 翻译模型 - agent-thought logo: dashscope + - name: qwen-mt-turbo type: llm provider: dashscope @@ -223,6 +258,7 @@ models: - 翻译模型 - agent-thought logo: dashscope + - name: qwen-plus-0112 type: llm provider: dashscope @@ -237,6 +273,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-plus-0125 type: llm provider: dashscope @@ -251,6 +288,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-plus-0723 type: llm provider: dashscope @@ -265,6 +303,7 @@ models: - agent-thought - 
stream-tool-call logo: dashscope + - name: qwen-plus-0806 type: llm provider: dashscope @@ -279,6 +318,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-plus-0919 type: llm provider: dashscope @@ -293,6 +333,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-plus-1125 type: llm provider: dashscope @@ -307,6 +348,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-plus-1127 type: llm provider: dashscope @@ -321,6 +363,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-plus-1220 type: llm provider: dashscope @@ -335,6 +378,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen-vl-max type: chat provider: dashscope @@ -342,8 +386,9 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - json_output is_omni: false tags: - 大语言模型 @@ -352,6 +397,7 @@ models: - agent-thought - video logo: dashscope + - name: qwen-vl-plus-0809 type: chat provider: dashscope @@ -359,8 +405,8 @@ models: is_deprecated: true is_official: true capability: - - vision - - video + - vision + - video is_omni: false tags: - 大语言模型 @@ -369,6 +415,7 @@ models: - agent-thought - video logo: dashscope + - name: qwen-vl-plus-2025-01-02 type: chat provider: dashscope @@ -376,8 +423,8 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video is_omni: false tags: - 大语言模型 @@ -386,6 +433,7 @@ models: - agent-thought - video logo: dashscope + - name: qwen-vl-plus-2025-01-25 type: chat provider: dashscope @@ -393,8 +441,8 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video is_omni: false tags: - 大语言模型 @@ -403,6 +451,7 @@ models: - agent-thought - video logo: dashscope + - name: qwen-vl-plus-latest type: chat provider: dashscope @@ -410,8 +459,9 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - 
vision + - video + - json_output is_omni: false tags: - 大语言模型 @@ -420,6 +470,7 @@ models: - agent-thought - video logo: dashscope + - name: qwen-vl-plus type: chat provider: dashscope @@ -427,8 +478,9 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - json_output is_omni: false tags: - 大语言模型 @@ -437,13 +489,15 @@ models: - agent-thought - video logo: dashscope + - name: qwen2.5-0.5b-instruct type: llm provider: dashscope description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -451,13 +505,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-14b type: llm provider: dashscope description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -465,13 +522,15 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-235b-a22b-instruct-2507 type: llm provider: dashscope description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -479,13 +538,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-235b-a22b-thinking-2507 type: llm provider: dashscope description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -493,13 +555,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-235b-a22b type: llm provider: dashscope description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: 
+ - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -507,13 +572,15 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-30b-a3b-instruct-2507 type: llm provider: dashscope description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -521,13 +588,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-30b-a3b type: llm provider: dashscope description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -535,13 +605,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-32b type: llm provider: dashscope description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -549,13 +622,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-4b type: llm provider: dashscope description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -563,13 +639,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-8b type: llm provider: dashscope description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -577,65 +656,78 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-coder-30b-a3b-instruct type: llm provider: dashscope description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + 
capability: + - json_output is_omni: false tags: - 大语言模型 - 代码模型 - agent-thought logo: dashscope + - name: qwen3-coder-480b-a35b-instruct type: llm provider: dashscope description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 - 代码模型 - agent-thought logo: dashscope + - name: qwen3-coder-plus-2025-09-23 type: llm provider: dashscope description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 - 代码模型 - agent-thought logo: dashscope + - name: qwen3-coder-plus type: llm provider: dashscope description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 - 代码模型 - agent-thought logo: dashscope + - name: qwen3-max-2025-09-23 type: llm provider: dashscope description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -644,13 +736,16 @@ models: - stream-tool-call - 联网搜索 logo: dashscope + - name: qwen3-max-2026-01-23 type: llm provider: dashscope description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -659,13 +754,16 @@ models: - stream-tool-call - 联网搜索 logo: dashscope + - name: qwen3-max-preview type: llm provider: dashscope description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -673,13 +771,16 @@ models: - agent-thought - 
stream-tool-call logo: dashscope + - name: qwen3-max type: llm provider: dashscope description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -688,13 +789,15 @@ models: - stream-tool-call - 联网搜索 logo: dashscope + - name: qwen3-next-80b-a3b-instruct type: llm provider: dashscope description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -702,13 +805,16 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-next-80b-a3b-thinking type: llm provider: dashscope description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -716,6 +822,7 @@ models: - agent-thought - stream-tool-call logo: dashscope + - name: qwen3-omni-flash-2025-12-01 type: llm provider: dashscope @@ -723,9 +830,11 @@ models: is_deprecated: false is_official: true capability: - - vision - - video - - audio + - vision + - video + - audio + - thinking + - json_output is_omni: true tags: - 大语言模型 @@ -735,6 +844,7 @@ models: - video - audio logo: dashscope + - name: qwen3-vl-235b-a22b-instruct type: chat provider: dashscope @@ -742,8 +852,9 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - json_output is_omni: false tags: - 大语言模型 @@ -754,6 +865,7 @@ models: - vision - video logo: dashscope + - name: qwen3-vl-235b-a22b-thinking type: chat provider: dashscope @@ -761,8 +873,10 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -773,6 +887,7 @@ models: - vision - video logo: dashscope + - 
name: qwen3-vl-30b-a3b-instruct type: chat provider: dashscope @@ -780,8 +895,9 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - json_output is_omni: false tags: - 大语言模型 @@ -792,6 +908,7 @@ models: - vision - video logo: dashscope + - name: qwen3-vl-30b-a3b-thinking type: chat provider: dashscope @@ -799,8 +916,10 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -811,6 +930,7 @@ models: - vision - video logo: dashscope + - name: qwen3-vl-flash type: chat provider: dashscope @@ -818,8 +938,10 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -830,6 +952,7 @@ models: - vision - video logo: dashscope + - name: qwen3-vl-plus-2025-09-23 type: chat provider: dashscope @@ -837,8 +960,10 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -847,6 +972,7 @@ models: - agent-thought - video logo: dashscope + - name: qwen3-vl-plus type: chat provider: dashscope @@ -854,8 +980,10 @@ models: is_deprecated: false is_official: true capability: - - vision - - video + - vision + - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -864,45 +992,55 @@ models: - agent-thought - video logo: dashscope + - name: qwq-32b type: llm provider: dashscope description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 - agent-thought - stream-tool-call logo: dashscope + - name: qwq-plus-0305 type: llm provider: dashscope description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking 
+ - json_output is_omni: false tags: - 大语言模型 - agent-thought - stream-tool-call logo: dashscope + - name: qwq-plus type: llm provider: dashscope description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 - agent-thought - stream-tool-call logo: dashscope + - name: gte-rerank-v2 type: rerank provider: dashscope @@ -914,6 +1052,7 @@ models: tags: - 重排序模型 logo: dashscope + - name: gte-rerank type: rerank provider: dashscope @@ -925,6 +1064,7 @@ models: tags: - 重排序模型 logo: dashscope + - name: multimodal-embedding-v1 type: embedding provider: dashscope @@ -932,13 +1072,14 @@ models: is_deprecated: false is_official: true capability: - - vision + - vision is_omni: false tags: - 嵌入模型 - 多模态模型 - vision logo: dashscope + - name: text-embedding-v1 type: embedding provider: dashscope @@ -951,6 +1092,7 @@ models: - 嵌入模型 - 文本嵌入 logo: dashscope + - name: text-embedding-v2 type: embedding provider: dashscope @@ -963,6 +1105,7 @@ models: - 嵌入模型 - 文本嵌入 logo: dashscope + - name: text-embedding-v3 type: embedding provider: dashscope @@ -975,6 +1118,7 @@ models: - 嵌入模型 - 文本嵌入 logo: dashscope + - name: text-embedding-v4 type: embedding provider: dashscope @@ -986,4 +1130,4 @@ models: tags: - 嵌入模型 - 文本嵌入 - logo: dashscope \ No newline at end of file + logo: dashscope diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 7f6d3a51..1c0a0b2d 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -10,6 +10,7 @@ models: - vision - audio - video + - json_output is_omni: true tags: - 大语言模型 @@ -20,13 +21,15 @@ models: - audio - video logo: openai + - name: gpt-3.5-turbo-0125 type: llm provider: openai description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - 
json_output is_omni: false tags: - 大语言模型 @@ -34,13 +37,15 @@ models: - agent-thought - stream-tool-call logo: openai + - name: gpt-3.5-turbo-1106 type: llm provider: openai description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -48,13 +53,15 @@ models: - agent-thought - stream-tool-call logo: openai + - name: gpt-3.5-turbo-16k type: llm provider: openai description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -62,6 +69,7 @@ models: - agent-thought - stream-tool-call logo: openai + - name: gpt-3.5-turbo-instruct type: llm provider: openai @@ -73,13 +81,15 @@ models: tags: - 大语言模型 logo: openai + - name: gpt-3.5-turbo type: llm provider: openai description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -87,13 +97,15 @@ models: - agent-thought - stream-tool-call logo: openai + - name: gpt-4-0125-preview type: llm provider: openai description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -101,13 +113,15 @@ models: - agent-thought - stream-tool-call logo: openai + - name: gpt-4-1106-preview type: llm provider: openai description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -115,6 +129,7 @@ models: - agent-thought - stream-tool-call logo: openai + - name: gpt-4-turbo-2024-04-09 type: llm provider: openai @@ -123,6 +138,7 @@ models: is_official: true capability: - vision + - json_output is_omni: false tags: - 大语言模型 
@@ -131,13 +147,15 @@ models: - stream-tool-call - vision logo: openai + - name: gpt-4-turbo-preview type: llm provider: openai description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -145,6 +163,7 @@ models: - agent-thought - stream-tool-call logo: openai + - name: gpt-4-turbo type: llm provider: openai @@ -153,6 +172,7 @@ models: is_official: true capability: - vision + - json_output is_omni: false tags: - 大语言模型 @@ -161,6 +181,7 @@ models: - stream-tool-call - vision logo: openai + - name: o1-preview type: llm provider: openai @@ -173,6 +194,7 @@ models: - 大语言模型 - agent-thought logo: openai + - name: o1 type: llm provider: openai @@ -181,6 +203,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -190,6 +214,7 @@ models: - vision - structured-output logo: openai + - name: o3-2025-04-16 type: llm provider: openai @@ -198,6 +223,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -207,13 +234,16 @@ models: - stream-tool-call - structured-output logo: openai + - name: o3-mini-2025-01-31 type: llm provider: openai description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -222,13 +252,16 @@ models: - stream-tool-call - structured-output logo: openai + - name: o3-mini type: llm provider: openai description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true - capability: [] + capability: + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -237,6 +270,7 @@ models: - stream-tool-call - structured-output logo: openai + - name: o3-pro-2025-06-10 type: llm provider: openai @@ -245,6 +279,8 @@ models: is_official: true 
capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -253,6 +289,7 @@ models: - vision - structured-output logo: openai + - name: o3-pro type: llm provider: openai @@ -261,6 +298,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -269,6 +308,7 @@ models: - vision - structured-output logo: openai + - name: o3 type: llm provider: openai @@ -277,6 +317,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -286,6 +328,7 @@ models: - stream-tool-call - structured-output logo: openai + - name: o4-mini-2025-04-16 type: llm provider: openai @@ -294,6 +337,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -303,6 +348,7 @@ models: - stream-tool-call - structured-output logo: openai + - name: o4-mini type: llm provider: openai @@ -311,6 +357,8 @@ models: is_official: true capability: - vision + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -320,6 +368,7 @@ models: - stream-tool-call - structured-output logo: openai + - name: text-embedding-3-large type: embedding provider: openai @@ -331,6 +380,7 @@ models: tags: - 文本向量模型 logo: openai + - name: text-embedding-3-small type: embedding provider: openai @@ -342,6 +392,7 @@ models: tags: - 文本向量模型 logo: openai + - name: text-embedding-ada-002 type: embedding provider: openai diff --git a/api/app/core/models/scripts/volcano_models.yaml b/api/app/core/models/scripts/volcano_models.yaml index 24609f5a..6658c2f9 100644 --- a/api/app/core/models/scripts/volcano_models.yaml +++ b/api/app/core/models/scripts/volcano_models.yaml @@ -10,6 +10,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -24,6 +26,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -38,6 +42,8 @@ models: capability: - vision - video + - thinking 
+ - json_output is_omni: false tags: - 大语言模型 @@ -52,6 +58,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -68,6 +76,7 @@ models: capability: - vision - video + - json_output is_omni: false tags: - 大语言模型 @@ -82,6 +91,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -96,6 +107,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -110,6 +123,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -124,6 +139,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -139,6 +156,8 @@ models: capability: - vision - video + - thinking + - json_output is_omni: false tags: - 大语言模型 @@ -166,7 +185,8 @@ models: description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 @@ -178,7 +198,8 @@ models: description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 is_deprecated: false is_official: true - capability: [] + capability: + - json_output is_omni: false tags: - 大语言模型 diff --git a/api/app/core/quota_manager.py b/api/app/core/quota_manager.py new file mode 100644 index 00000000..6c02ac7a --- /dev/null +++ b/api/app/core/quota_manager.py @@ -0,0 +1,473 @@ +""" +统一配额管理器 - 社区版和 SaaS 版共用 + +配额来源策略: +1. 优先从 premium 模块的 tenant_subscriptions 表读取(SaaS 版) +2. 
降级到 default_free_plan.py 配置文件(社区版兜底) +""" +import asyncio +import time +from functools import wraps +from typing import Optional, Callable, Dict, Any +from uuid import UUID + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.core.logging_config import get_auth_logger +from app.i18n.exceptions import QuotaExceededError + +logger = get_auth_logger() + + +def _get_user_from_kwargs(kwargs: dict): + """从 kwargs 中获取 user 对象""" + for key in ["user", "current_user"]: + if key in kwargs: + return kwargs[key] + return None + + +def _get_tenant_id_from_kwargs(db: Session, kwargs: dict): + """从 kwargs 中获取 tenant_id""" + user = _get_user_from_kwargs(kwargs) + if user and hasattr(user, 'tenant_id'): + return user.tenant_id + + workspace_id = kwargs.get("workspace_id") + if workspace_id: + from app.models.workspace_model import Workspace + workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first() + if workspace: + return workspace.tenant_id + + api_key_auth = kwargs.get("api_key_auth") + if api_key_auth and hasattr(api_key_auth, 'workspace_id'): + from app.models.workspace_model import Workspace + workspace = db.query(Workspace).filter(Workspace.id == api_key_auth.workspace_id).first() + if workspace: + return workspace.tenant_id + + data = kwargs.get("data") or kwargs.get("body") or kwargs.get("payload") + if data and hasattr(data, "workspace_id"): + from app.models.workspace_model import Workspace + workspace = db.query(Workspace).filter(Workspace.id == data.workspace_id).first() + if workspace: + return workspace.tenant_id + + return None + + +def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]: + """ + 获取租户的配额配置 + + 优先级: + 1. premium 模块的 tenant_subscriptions(SaaS 版) + 2. 
class QuotaUsageRepository:
    """Data-access layer that measures a tenant's current usage per quota type.

    Model classes are imported inside each method rather than at module
    level (presumably to avoid circular imports — confirm).
    """

    def __init__(self, db: "Session"):
        """Store the SQLAlchemy session used by every query."""
        self.db = db

    def count_workspaces(self, tenant_id: UUID) -> int:
        """Count the tenant's active workspaces."""
        from app.models.workspace_model import Workspace
        return self.db.query(Workspace).filter(
            Workspace.tenant_id == tenant_id,
            Workspace.is_active.is_(True)
        ).count()

    def count_apps(self, tenant_id: UUID) -> int:
        """Count active apps in any workspace owned by the tenant."""
        from app.models.app_model import App
        from app.models.workspace_model import Workspace
        return self.db.query(App).join(
            Workspace, App.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id,
            App.is_active.is_(True)
        ).count()

    def count_skills(self, tenant_id: UUID) -> int:
        """Count the tenant's active skills."""
        from app.models.skill_model import Skill
        return self.db.query(Skill).filter(
            Skill.tenant_id == tenant_id,
            Skill.is_active.is_(True)
        ).count()

    def sum_knowledge_capacity_gb(self, tenant_id: UUID) -> float:
        """Return the summed ``file_size`` of the tenant's documents, divided by 1024**3.

        Only documents with ``status == 1`` are counted.
        NOTE(review): presumably ``file_size`` is in bytes and status 1 means
        "successfully processed" — confirm against the Document model.
        """
        from app.models.document_model import Document
        from app.models.knowledge_model import Knowledge
        from app.models.workspace_model import Workspace
        result = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join(
            Knowledge, Document.kb_id == Knowledge.id
        ).join(
            Workspace, Knowledge.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id,
            Document.status == 1,
        ).scalar()
        return float(result) / (1024 ** 3) if result else 0.0

    def count_memory_engines(self, tenant_id: UUID) -> int:
        """Count memory configurations in the tenant's workspaces."""
        from app.models.memory_config_model import MemoryConfig
        from app.models.workspace_model import Workspace
        return self.db.query(MemoryConfig).join(
            Workspace, MemoryConfig.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id
        ).count()

    def count_end_users(self, tenant_id: UUID) -> int:
        """Count end users in the tenant's workspaces."""
        from app.models.end_user_model import EndUser
        from app.models.workspace_model import Workspace
        return self.db.query(EndUser).join(
            Workspace, EndUser.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id
        ).count()

    def count_models(self, tenant_id: UUID) -> int:
        """Count the tenant's active model configurations."""
        from app.models.models_model import ModelConfig
        return self.db.query(ModelConfig).filter(
            ModelConfig.tenant_id == tenant_id,
            # ``.is_(True)`` matches the other queries in this class and avoids
            # the ``== True`` comparison that linters flag (E712).
            ModelConfig.is_active.is_(True)
        ).count()

    def count_ontology_projects(self, tenant_id: UUID) -> int:
        """Count ontology scenes in the tenant's workspaces."""
        from app.models.ontology_scene import OntologyScene
        from app.models.workspace_model import Workspace
        return self.db.query(OntologyScene).join(
            Workspace, OntologyScene.workspace_id == Workspace.id
        ).filter(
            Workspace.tenant_id == tenant_id
        ).count()

    def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str):
        """Dispatch on ``quota_type`` and return the current usage.

        Args:
            tenant_id: Tenant whose usage is measured.
            quota_type: One of the ``*_quota`` keys listed in ``dispatch``.

        Returns:
            The usage figure from the matching method (a float for
            ``knowledge_capacity_quota``), or ``0`` for an unrecognised
            ``quota_type``.
        """
        dispatch = {
            "workspace_quota": self.count_workspaces,
            "app_quota": self.count_apps,
            "skill_quota": self.count_skills,
            "knowledge_capacity_quota": self.sum_knowledge_capacity_gb,
            "memory_engine_quota": self.count_memory_engines,
            "end_user_quota": self.count_end_users,
            "model_quota": self.count_models,
            "ontology_project_quota": self.count_ontology_projects,
        }
        fn = dispatch.get(quota_type)
        return fn(tenant_id) if fn else 0
def check_workspace_quota(func: Callable) -> Callable:
    """Decorator enforcing the tenant's ``workspace_quota`` before ``func`` runs.

    The session is read from the ``db`` keyword argument and the user from
    ``user`` / ``current_user``.  If either is missing, a warning is logged
    and ``func`` runs without a quota check.  ``QuotaExceededError`` raised
    by ``_check_quota`` propagates to the caller.
    """
    @wraps(func)
    def guarded(*args, **kwargs):
        session = kwargs.get("db")
        caller = _get_user_from_kwargs(kwargs)
        if session and caller:
            _check_quota(session, caller.tenant_id, "workspace_quota", "workspace")
        else:
            # Fail open: the call proceeds when the check cannot be performed.
            logger.warning("配额检查失败:缺少 db 或 user 参数")
        return func(*args, **kwargs)

    return guarded
def check_knowledge_capacity_quota(func: Callable) -> Callable:
    """Decorator enforcing ``knowledge_capacity_quota`` for sync or async callables.

    The tenant is resolved with ``_get_tenant_id_from_kwargs`` in both the
    sync and the async wrapper, as ``check_end_user_quota`` already does.
    The sync path previously required a ``user`` keyword and skipped the
    check for callers identified only by a workspace or API key.

    If ``db`` is missing or no tenant can be resolved, a warning is logged
    and ``func`` runs without a quota check.  ``QuotaExceededError`` raised
    by ``_check_quota`` propagates to the caller.
    """
    def _enforce(kwargs: dict) -> None:
        """Run the quota check shared by both wrappers."""
        db = kwargs.get("db")
        if not db:
            logger.warning("配额检查失败:缺少 db 参数")
            return
        tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
        if not tenant_id:
            logger.warning("配额检查失败:无法获取 tenant_id")
            return
        _check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")

    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        _enforce(kwargs)
        return await func(*args, **kwargs)

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        _enforce(kwargs)
        return func(*args, **kwargs)

    # Pick the wrapper matching the callable so coroutine functions stay awaitable.
    return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_model_quota(func: Callable) -> Callable:
    """Decorator that checks the tenant's ``model_quota`` before calling ``func``.

    It needs a ``db`` keyword argument and a user under ``user`` or
    ``current_user``.  When either is absent, the check is skipped with a
    warning and ``func`` is still invoked.
    """
    @wraps(func)
    def checked_call(*args, **kwargs):
        db_session, requester = kwargs.get("db"), _get_user_from_kwargs(kwargs)

        if db_session and requester:
            _check_quota(db_session, requester.tenant_id, "model_quota", "model")
        else:
            logger.warning("配额检查失败:缺少 db 或 user 参数")

        return func(*args, **kwargs)

    return checked_call
def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
    """Return usage, limit and percentage for every quota of a tenant.

    Args:
        db: Database session used for the usage queries.
        tenant_id: Tenant to report on.

    Returns:
        ``{}`` when the tenant has no quota configuration.  Otherwise a dict
        keyed by resource name whose values hold ``used``, ``limit`` and
        ``percentage`` (``None`` when the limit is missing or zero).
        ``knowledge_capacity`` also carries ``unit``.  ``api_ops_rate_limit``
        reports ``current`` instead of ``used``, read best-effort from Redis
        and left at 0 if Redis is unavailable.
    """
    quota_config = _get_quota_config(db, tenant_id)
    if not quota_config:
        return {}

    repo = QuotaUsageRepository(db)

    def pct(used, limit):
        return round(used / limit * 100, 1) if limit else None

    def entry(used, quota_key):
        """Build the standard used/limit/percentage record for one quota."""
        limit = quota_config.get(quota_key)
        return {"used": used, "limit": limit, "percentage": pct(used, limit)}

    workspace_count = repo.count_workspaces(tenant_id)
    skill_count = repo.count_skills(tenant_id)
    app_count = repo.count_apps(tenant_id)
    knowledge_gb = repo.sum_knowledge_capacity_gb(tenant_id)
    memory_count = repo.count_memory_engines(tenant_id)
    end_user_count = repo.count_end_users(tenant_id)
    model_count = repo.count_models(tenant_id)
    ontology_count = repo.count_ontology_projects(tenant_id)

    # Best-effort read of API calls in the last second; a failure here must
    # never break the usage report.
    api_ops_current = 0
    redis_client = None
    try:
        from app.core.config import settings
        import redis
        now = time.time()
        rate_key = f"rate_limit:tenant_qps:{tenant_id}"
        redis_client = redis.StrictRedis(
            host=settings.REDIS_HOST, port=settings.REDIS_PORT,
            db=settings.REDIS_DB, password=settings.REDIS_PASSWORD,
            decode_responses=True
        )
        api_ops_current = int(redis_client.zcount(rate_key, now - 1, "+inf"))
    except Exception as e:
        # Previously swallowed silently; log it so a Redis outage is visible.
        logger.warning(f"获取 API 调用速率失败,按 0 处理: tenant={tenant_id}, error={e}")
    finally:
        if redis_client is not None:
            # Each call builds its own connection pool; release it explicitly
            # instead of waiting for garbage collection.
            try:
                redis_client.connection_pool.disconnect()
            except Exception as e:
                logger.debug(f"关闭 Redis 连接失败: {e}")

    knowledge_limit = quota_config.get("knowledge_capacity_quota")
    return {
        "workspace": entry(workspace_count, "workspace_quota"),
        "skill": entry(skill_count, "skill_quota"),
        "app": entry(app_count, "app_quota"),
        # ``used`` is rounded for display; the percentage uses the exact value.
        "knowledge_capacity": {
            "used": round(knowledge_gb, 2),
            "limit": knowledge_limit,
            "percentage": pct(knowledge_gb, knowledge_limit),
            "unit": "GB",
        },
        "memory_engine": entry(memory_count, "memory_engine_quota"),
        "end_user": entry(end_user_count, "end_user_quota"),
        "ontology_project": entry(ontology_count, "ontology_project_quota"),
        "model": entry(model_count, "model_quota"),
        "api_ops_rate_limit": {
            "current": api_ops_current,
            "limit": quota_config.get("api_ops_rate_limit"),
            "percentage": None,
            "unit": "次/秒",
        },
    }
社区版:从 default_free_plan.py 读取配额限制 +- SaaS 版:优先从 tenant_subscriptions 表读取,降级到配置文件 +""" +from app.core.quota_manager import ( + check_workspace_quota, + check_skill_quota, + check_app_quota, + check_knowledge_capacity_quota, + check_memory_engine_quota, + check_end_user_quota, + check_ontology_project_quota, + check_model_quota, + check_model_activation_quota, + get_quota_usage, + _check_quota, + QuotaUsageRepository, +) + +__all__ = [ + "check_workspace_quota", + "check_skill_quota", + "check_app_quota", + "check_knowledge_capacity_quota", + "check_memory_engine_quota", + "check_end_user_quota", + "check_ontology_project_quota", + "check_model_quota", + "check_model_activation_quota", + "get_quota_usage", + "_check_quota", + "QuotaUsageRepository", +] diff --git a/api/app/core/rag/app/naive.py b/api/app/core/rag/app/naive.py index 72272347..312216dd 100644 --- a/api/app/core/rag/app/naive.py +++ b/api/app/core/rag/app/naive.py @@ -672,10 +672,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, excel_parser = ExcelParser() if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true": sections = [(_, "") for _ in excel_parser.html(binary, 12) if _] - parser_config["chunk_token_num"] = 0 else: sections = [(_, "") for _ in excel_parser(binary) if _] - parser_config["chunk_token_num"] = 12800 + callback(0.8, "Finish parsing.") + # Excel 每行直接作为一个 chunk,不经过 naive_merge 避免被 delimiter 拆分 + chunks = [s for s, _ in sections] + res.extend(tokenize_chunks(chunks, doc, is_english, None)) + res.extend(embed_res) + res.extend(url_res) + return res elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") diff --git a/api/app/core/rag/deepdoc/parser/excel_parser.py b/api/app/core/rag/deepdoc/parser/excel_parser.py index d66a21a8..c3999be9 100644 --- a/api/app/core/rag/deepdoc/parser/excel_parser.py +++ b/api/app/core/rag/deepdoc/parser/excel_parser.py @@ -232,14 +232,14 
@@ class RAGExcelParser: t = str(ti[i].value) if i < len(ti) else "" t += (":" if t else "") + str(c.value) fields.append(t) - line = "; ".join(fields) + line = "\n".join(fields) if sheetname.lower().find("sheet") < 0: - line += " ——" + sheetname + line += "\n——" + sheetname res.append(line) else: # 只有表头的情况 if header_fields: - line = "; ".join(header_fields) + line = "\n".join(header_fields) if sheetname.lower().find("sheet") < 0: line += " ——" + sheetname res.append(line) diff --git a/api/app/core/rag/deepdoc/parser/mineru_parser.py b/api/app/core/rag/deepdoc/parser/mineru_parser.py index fe6178ec..c2f7af16 100644 --- a/api/app/core/rag/deepdoc/parser/mineru_parser.py +++ b/api/app/core/rag/deepdoc/parser/mineru_parser.py @@ -292,9 +292,10 @@ class MinerUParser(RAGPdfParser): self.page_from = page_from self.page_to = page_to try: - with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf: - self.pdf = pdf - self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])] + with sys.modules[LOCK_KEY_pdfplumber]: # ← 加这一行,获取全局锁 + with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf: + self.pdf = pdf + self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])] except Exception as e: self.page_images = None self.total_page = 0 diff --git a/api/app/core/rag/llm/embedding_model.py b/api/app/core/rag/llm/embedding_model.py index 22e35a15..59210054 100644 --- a/api/app/core/rag/llm/embedding_model.py +++ b/api/app/core/rag/llm/embedding_model.py @@ -50,7 +50,9 @@ class OpenAIEmbed(Base): def encode(self, texts: list): # OpenAI requires batch size <=16 batch_size = 16 - texts = [truncate(t, 8191) for t in texts] + # Use 8000 instead of 8191 to leave safety margin for tokenizer differences + # between cl100k_base (used by 
truncate) and the actual embedding model + texts = [truncate(t, 8000) for t in texts] ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): @@ -63,7 +65,7 @@ class OpenAIEmbed(Base): return np.array(ress), total_tokens def encode_queries(self, text): - res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True}) + res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True}) return np.array(res.data[0].embedding), self.total_token_count(res) @@ -79,6 +81,7 @@ class LocalAIEmbed(Base): def encode(self, texts: list): batch_size = 16 + texts = [truncate(t, 8000) for t in texts] ress = [] for i in range(0, len(texts), batch_size): res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) @@ -173,6 +176,7 @@ class XinferenceEmbed(Base): def encode(self, texts: list): batch_size = 16 + texts = [truncate(t, 8000) for t in texts] ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): @@ -188,7 +192,7 @@ class XinferenceEmbed(Base): def encode_queries(self, text): res = None try: - res = self.client.embeddings.create(input=[text], model=self.model_name) + res = self.client.embeddings.create(input=[truncate(text, 8000)], model=self.model_name) return np.array(res.data[0].embedding), self.total_token_count(res) except Exception as _e: log_exception(_e, res) diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index db93bc48..61540ee4 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -28,6 +28,7 @@ from app.core.rag.common.float_utils import get_float from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.llm.chat_model import Base from app.core.rag.llm.embedding_model import OpenAIEmbed +from app.services.model_service import ModelApiKeyService import logging 
logger = logging.getLogger(__name__) @@ -114,9 +115,8 @@ def knowledge_retrieval( # Use the specified reranker for re-ranking if reranker_id: try: - return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) + all_results = rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) except Exception as rerank_error: - # If reranker fails, log warning and continue with original results logger.warning( "Reranker failed, falling back to original results", extra={ @@ -132,7 +132,10 @@ def knowledge_retrieval( from app.core.rag.common.settings import kg_retriever doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) if doc: - all_results.insert(0, doc) + all_results.insert(0, DocumentChunk( + page_content=doc.get("page_content", ""), + metadata=doc.get("metadata", {}) + )) except Exception as graph_error: print(f"Failed to retrieve from knowledge graph: {str(graph_error)}") @@ -198,16 +201,18 @@ def _retrieve_for_knowledge( workspace_ids.append(str(db_knowledge.workspace_id)) if not chat_model: + llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id) chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base, + key=llm_key.api_key, + model_name=llm_key.model_name, + base_url=llm_key.api_base, ) if not embedding_model: + emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id) embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base, + key=emb_key.api_key, + model_name=emb_key.model_name, + base_url=emb_key.api_base, ) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) @@ -248,6 +253,29 @@ def 
_retrieve_for_knowledge( seen_ids.add(doc.metadata["doc_id"]) unique_rs.append(doc) rs = unique_rs + if unique_rs: + rs = vector_service.rerank( + query=kb_config["query"], + docs=unique_rs, + top_k=kb_config["top_k"] + ) + if kb_config["retrieve_type"] == "graph": + try: + from app.core.rag.common.settings import kg_retriever + graph_doc = kg_retriever.retrieval( + question=kb_config["query"], + workspace_ids=[str(db_knowledge.workspace_id)], + kb_ids=[str(db_knowledge.id)], + emb_mdl=embedding_model, + llm=chat_model, + ) + if graph_doc: + rs.insert(0, DocumentChunk( + page_content=graph_doc.get("page_content", ""), + metadata=graph_doc.get("metadata", {}) + )) + except Exception as graph_error: + logger.warning(f"Graph retrieval failed for kb {db_knowledge.id}: {graph_error}") results.extend(rs) return results, chat_model, embedding_model diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py index 00004dfe..2fda6b8b 100644 --- a/api/app/core/tools/builtin/datetime_tool.py +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool): type=ParameterType.STRING, description="操作类型", required=True, - enum=["format", "convert_timezone", "timestamp_to_datetime", "now"] + enum=["format", "convert_timezone", "timestamp_to_datetime", "now", "datetime_to_timestamp"] ), ToolParameter( name="input_value", @@ -230,7 +230,7 @@ class DateTimeTool(BuiltinTool): @staticmethod def _datetime_to_timestamp(kwargs) -> dict: """日期时间转时间戳""" - input_value = kwargs.get("input_value") + input_value = kwargs.get("input_value").strip() input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") timezone_str = kwargs.get("from_timezone", "Asia/Shanghai") @@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool): return { "datetime": input_value, "timezone": timezone_str, - "timestamp": int(dt.timestamp()), + "timestamp": int(dt.timestamp()) * 1000, "iso_format": dt.isoformat(), - "result_data": 
def __init__(self, tool_id: str, config: Dict[str, Any]):
    """Initialise the tool from its stored configuration.

    Args:
        tool_id: Identifier forwarded to the base-class constructor.
        config: Tool configuration forwarded to the base-class constructor.
            The values used here are then read from ``self.parameters_config``.
    """
    super().__init__(tool_id, config)

    # Runtime context; replaced later through set_runtime_context().
    self._uploaded_files = []
    self._conversation_id = None
    self._user_id = "anonymous"

    # Fixed internal defaults.
    self._timeout = 120
    self._session_strategy = "by_user"
    self._model = "openclaw"

    # User-supplied settings (filled in through the front-end form).
    settings = self.parameters_config
    self._agent_id = settings.get("agent_id", "main")
    self._api_key = settings.get("api_key", "")
    self._server_url = settings.get("server_url", "")
OpenClaw Agent 的文本请求内容", + required=True + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="可选,附带的图片 URL 或 base64 data URI(OpenClaw 支持图片输入)", + required=False + ) + ] + + # ---------- 运行时上下文注入 ---------- + def set_runtime_context( + self, + user_id: str = "anonymous", + conversation_id: Optional[str] = None, + uploaded_files: Optional[list] = None + ): + """注入运行时上下文(由 chat service 调用)""" + self._user_id = user_id + self._conversation_id = conversation_id + self._uploaded_files = uploaded_files or [] + + # ---------- 连接测试 ---------- + async def test_connection(self) -> Dict[str, Any]: + """测试 OpenClaw Gateway 连接""" + if not self._server_url: + return {"success": False, "message": "未配置 server_url"} + if not self._api_key: + return {"success": False, "message": "未配置 api_key"} + + url = f"{self._server_url.rstrip('/')}/v1/responses" + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + "x-openclaw-agent-id": self._agent_id + } + body = { + "model": self._model, + "user": "connection-test", + "input": "hi", + "stream": False + } + try: + timeout_cfg = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.post(url, json=body, headers=headers) as resp: + if resp.status < 400: + return {"success": True, "message": "OpenClaw 连接成功"} + error_text = await resp.text() + return { + "success": False, + "message": f"OpenClaw HTTP {resp.status}: {error_text[:200]}" + } + except Exception as e: + return {"success": False, "message": f"OpenClaw 连接失败: {str(e)}"} + + # ---------- 执行 ---------- + async def execute(self, **kwargs) -> ToolResult: + """执行 OpenClaw 调用""" + start_time = time.time() + try: + message = kwargs.get("message", "") + if not message: + return ToolResult.error_result( + error="message 参数不能为空", + error_code="OPENCLAW_INVALID_INPUT", + execution_time=time.time() - start_time + ) + + # 提取图片:优先从用户上传文件中获取,LLM 传的 image_url 
作为兜底 + image_url = self._extract_image_from_uploads() + if not image_url: + image_url = kwargs.get("image_url") + if image_url and not image_url.startswith("data:"): + image_url = await self._download_and_encode_image(image_url) + + # 构建请求 + url = f"{self._server_url.rstrip('/')}/v1/responses" + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + "x-openclaw-agent-id": self._agent_id + } + user_field = ( + f"conv-{self._conversation_id}" + if self._session_strategy == "by_conversation" and self._conversation_id + else f"user-{self._user_id}" + ) + input_field = self._build_input(message, image_url) + body = { + "model": self._model, + "user": user_field, + "input": input_field, + "stream": False + } + + timeout_cfg = aiohttp.ClientTimeout(total=self._timeout) + # 打印请求日志(截断 base64 避免日志过大) + log_body = {**body} + if isinstance(log_body.get("input"), list): + log_body["input"] = "[multimodal input, truncated]" + elif isinstance(log_body.get("input"), str) and len(log_body["input"]) > 500: + log_body["input"] = log_body["input"][:500] + "..." 
+ logger.info( + f"OpenClaw 请求: url={url}, agent_id={self._agent_id}, " + f"has_image={bool(image_url)}, body={log_body}" + ) + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.post(url, json=body, headers=headers) as resp: + execution_time = time.time() - start_time + if resp.status >= 400: + error_text = await resp.text() + return ToolResult.error_result( + error=f"OpenClaw HTTP {resp.status}: {error_text[:500]}", + error_code="OPENCLAW_HTTP_ERROR", + execution_time=execution_time + ) + data = await resp.json() + text = self._extract_response(data) + display_text = self._format_result(text) + return ToolResult.success_result( + data=display_text, + execution_time=execution_time + ) + + except aiohttp.ClientError as e: + return ToolResult.error_result( + error=f"OpenClaw 网络连接失败: {str(e)}", + error_code="OPENCLAW_NETWORK_ERROR", + execution_time=time.time() - start_time + ) + except Exception as e: + return ToolResult.error_result( + error=f"OpenClaw 调用失败: {str(e)}", + error_code="OPENCLAW_EXECUTION_ERROR", + execution_time=time.time() - start_time + ) + + # ---------- 私有方法 ---------- + def _extract_image_from_uploads(self) -> Optional[str]: + """从用户上传文件中提取图片 URL""" + for f in self._uploaded_files: + f_type = f.get("type", "") + if f_type == "image": + source = f.get("source", {}) + if source.get("type") == "base64": + media_type = source.get("media_type", "image/jpeg") + data = source.get("data", "") + return f"data:{media_type};base64,{data}" + elif f.get("image"): + return f.get("image") + elif f.get("url"): + return f.get("url") + elif f_type == "image_url": + return f.get("image_url", {}).get("url", "") + return None + + async def _download_and_encode_image(self, image_url: str) -> str: + """下载图片并转为 base64 data URI""" + try: + from PIL import Image + MAX_RAW_SIZE = 4 * 1024 * 1024 + + async with aiohttp.ClientSession() as session: + async with session.get( + image_url, allow_redirects=True, + 
timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + if resp.status != 200: + return image_url + content_type = resp.headers.get("Content-Type", "image/jpeg") + if not content_type.startswith("image/"): + return image_url + img_bytes = await resp.read() + + if len(img_bytes) > MAX_RAW_SIZE: + img = Image.open(BytesIO(img_bytes)) + if img.mode in ("RGBA", "P", "LA"): + img = img.convert("RGB") + if max(img.size) > 2048: + img.thumbnail((2048, 2048), Image.LANCZOS) + buf = BytesIO() + img.save(buf, format="JPEG", quality=75, optimize=True) + img_bytes = buf.getvalue() + content_type = "image/jpeg" + + b64 = base64.b64encode(img_bytes).decode("utf-8") + return f"data:{content_type};base64,{b64}" + except Exception as e: + logger.warning(f"OpenClaw 下载图片失败,使用原始 URL: {e}") + return image_url + + def _build_input(self, message: str, image_url: Optional[str] = None): + """构造请求 input 字段:有图片则构造多模态结构,否则纯文本""" + if not image_url: + return message + + content_parts = [{"type": "input_text", "text": message}] + if image_url.startswith("data:"): + try: + header, data = image_url.split(",", 1) + media_type = header.split(":")[1].split(";")[0] + content_parts.append({ + "type": "input_image", + "source": {"type": "base64", "media_type": media_type, "data": data} + }) + except (ValueError, IndexError): + return message + else: + content_parts.append({ + "type": "input_image", + "source": {"type": "url", "url": image_url} + }) + + return [{"type": "message", "role": "user", "content": content_parts}] + + def _extract_response(self, response_data: Dict[str, Any]) -> str: + """从 OpenClaw 响应中提取文本内容 + + OpenClaw /v1/responses 只返回 output_text 类型的内容。 + 图片信息(如有)由 OpenClaw Skill 以 Markdown 链接形式嵌入文本中返回。 + """ + output = response_data.get("output", []) + texts = [] + for item in output: + if item.get("type") == "message": + for content in item.get("content", []): + if content.get("type") == "output_text" and content.get("text"): + texts.append(content["text"]) + return "\n".join(texts) if 
texts else str(response_data) + + @staticmethod + def _format_result(text: str) -> str: + """格式化结果为 LLM 可读字符串""" + return text or "(OpenClaw 返回了空内容)" diff --git a/api/app/core/tools/builtin/operation_tool.py b/api/app/core/tools/builtin/operation_tool.py index 126541a8..e8b7c77e 100644 --- a/api/app/core/tools/builtin/operation_tool.py +++ b/api/app/core/tools/builtin/operation_tool.py @@ -11,6 +11,11 @@ class OperationTool(BaseTool): self.base_tool = base_tool self.operation = operation super().__init__(base_tool.tool_id, base_tool.config) + + def set_runtime_context(self, **kwargs): + """转发运行时上下文到 base_tool""" + if hasattr(self.base_tool, 'set_runtime_context'): + self.base_tool.set_runtime_context(**kwargs) @property def name(self) -> str: @@ -32,6 +37,8 @@ class OperationTool(BaseTool): return self._get_datetime_params() elif self.base_tool.name == 'json_tool': return self._get_json_params() + elif self.base_tool.name == 'openclaw_tool': + return self._get_openclaw_params() else: # 默认返回除operation外的所有参数 return [p for p in self.base_tool.parameters if p.name != "operation"] @@ -138,6 +145,29 @@ class OperationTool(BaseTool): default="Asia/Shanghai" ) ] + elif self.operation == "datetime_to_timestamp": + return [ + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串,如:2026-04-07 10:30:25)", + required=True + ), + ToolParameter( + name="input_format", + type=ParameterType.STRING, + description="输入时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="from_timezone", + type=ParameterType.STRING, + description="源时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ) + ] else: return [] @@ -209,6 +239,64 @@ class OperationTool(BaseTool): else: return base_params + def _get_openclaw_params(self) -> List[ToolParameter]: + """获取 openclaw_tool 特定操作的参数""" + if self.operation == "print_task": + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, 
+ description="发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw,禁止改写、补充或润色用户的原文", + required=True + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="可选,附带的设计图片或参考图,OpenClaw 可据此生成 3D 模型", + required=False + ) + ] + elif self.operation == "device_query": + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw 的设备查询指令", + required=True + ) + ] + elif self.operation == "image_understand": + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)", + required=True + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="要分析的图片 URL 或 base64 data URI", + required=False + ) + ] + else: + # general 及其他 + return [ + ToolParameter( + name="message", + type=ParameterType.STRING, + description="发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求", + required=True + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="可选,附带的图片 URL 或 base64 data URI", + required=False + ) + ] + async def execute(self, **kwargs) -> ToolResult: """执行特定操作""" # 添加operation参数 diff --git a/api/app/core/tools/configs/builtin/openclaw_tool.json b/api/app/core/tools/configs/builtin/openclaw_tool.json new file mode 100644 index 00000000..7c1f9629 --- /dev/null +++ b/api/app/core/tools/configs/builtin/openclaw_tool.json @@ -0,0 +1,15 @@ +{ + "name": "openclaw_tool", + "description": "调用OpenClaw Agent远程服务", + "tool_class": "OpenClawTool", + "category": "agent", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "server_url": "", + "api_key": "", + "agent_id": "main" + }, + "tags": ["agent", "openclaw", "multimodal", "3d-printing", "builtin"] +} diff --git a/api/app/core/tools/configs/builtin_tools.json b/api/app/core/tools/configs/builtin_tools.json index 79206a5e..882a970a 100644 --- a/api/app/core/tools/configs/builtin_tools.json +++ 
b/api/app/core/tools/configs/builtin_tools.json @@ -30,5 +30,18 @@ "parameters": { "api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true} } + }, + "openclaw": { + "name": "OpenClaw远程Agent", + "description": "OpenClaw Agent远程服务", + "tool_class": "OpenClawTool", + "category": "agent", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "server_url": {"type": "string", "description": "OpenClaw Gateway 地址", "required": true}, + "api_key": {"type": "string", "description": "OpenClaw API Key", "sensitive": true, "required": true} + } } } \ No newline at end of file diff --git a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py index 3dfe4c93..c03fe206 100644 --- a/api/app/core/tools/custom/base.py +++ b/api/app/core/tools/custom/base.py @@ -30,7 +30,7 @@ class CustomTool(BaseTool): self.auth_config = config.get("auth_config", {}) self.base_url = config.get("base_url", "") self.timeout = config.get("timeout", 30) - + # 解析schema self._parsed_operations = self._parse_openapi_schema() diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index 51415732..859b6312 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -131,7 +131,7 @@ class LangchainAdapter: def _tool_supports_operations(tool: BaseTool) -> bool: """检查工具是否支持多操作""" # 内置工具中支持操作的工具 - builtin_operation_tools = ['datetime_tool', 'json_tool'] + builtin_operation_tools = ['datetime_tool', 'json_tool', 'openclaw_tool'] # 检查内置工具 if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools: diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index b437d021..3539d33a 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -99,7 +99,7 @@ class SimpleMCPClient: # 建立 SSE 连接 response = await self._session.get(self.server_url) - if response.status not in (200, 202): + 
if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") @@ -190,9 +190,7 @@ class SimpleMCPClient: try: async with self._session.post(self._endpoint_url, json=request) as response: - # MCP SSE 协议:POST 请求返回 200 或 202 均为正常 - # 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回 - if response.status not in (200, 202): + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") @@ -207,7 +205,7 @@ class SimpleMCPClient: raise MCPConnectionError("endpoint URL 未初始化") async with self._session.post(self._endpoint_url, json=notification) as response: - if response.status not in (200, 202): + if not (200 <= response.status < 300): logger.warning(f"通知发送失败: {response.status}") async def _initialize_modelscope_session(self): @@ -225,7 +223,7 @@ class SimpleMCPClient: try: async with self._session.post(self.server_url, json=init_request) as response: - if response.status != 200: + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py index 2e24d085..41090983 100644 --- a/api/app/core/workflow/adapters/base_adapter.py +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -40,6 +40,7 @@ class WorkflowParserResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) + features: dict[str, Any] = Field(default_factory=dict) warnings: list[ExceptionDefinition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list) @@ -51,6 +52,7 @@ class WorkflowImportResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = 
Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) + features: dict[str, Any] = Field(default_factory=dict) warnings: list[ExceptionDefinition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list) diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 4fa9508b..ad9312e1 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -15,7 +15,7 @@ from app.core.workflow.adapters.errors import ( ExceptionType ) from app.core.workflow.nodes.assigner.config import AssignmentItem -from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig +from app.core.workflow.nodes.base_config import VariableDefinition as NodeVariableDefinition, BaseNodeConfig from app.core.workflow.nodes.code.config import InputVariable, OutputVariable from app.core.workflow.nodes.configs import ( StartNodeConfig, @@ -32,13 +32,17 @@ from app.core.workflow.nodes.configs import ( NoteNodeConfig, ParameterExtractorNodeConfig, QuestionClassifierNodeConfig, - VariableAggregatorNodeConfig + VariableAggregatorNodeConfig, + ListOperatorNodeConfig, + DocExtractorNodeConfig, ) +from app.schemas.workflow_schema import VariableDefinition as SchemaVariableDefinition from app.core.workflow.nodes.cycle_graph.config import ( ConditionDetail as LoopConditionDetail, ConditionsConfig, CycleVariable ) +from app.core.workflow.nodes.list_operator.config import FilterCondition from app.core.workflow.nodes.enums import ( ValueInputType, ComparisonOperator, @@ -90,9 +94,12 @@ class DifyConverter(BaseConverter): NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config, NodeType.TOOL: self.convert_tool_node_config, NodeType.NOTES: self.convert_notes_config, + NodeType.LIST_OPERATOR: self.convert_list_operator_node_config, + NodeType.DOCUMENT_EXTRACTOR: 
self.convert_document_extractor_node_config, NodeType.CYCLE_START: lambda x: {}, NodeType.BREAK: lambda x: {}, } + self._file_vars_to_conv: list[SchemaVariableDefinition] = [] def get_node_convert(self, node_type): func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {}) @@ -126,7 +133,7 @@ class DifyConverter(BaseConverter): selector = var_selector.split('.') if len(selector) not in [2, 3] and var_selector != "context": raise Exception(f"invalid variable selector: {var_selector}") - if len(selector) == 3: + if len(selector) == 3 and selector[0] in ("conversation", "sys"): selector = selector[1:] if selector[0] == "conversation": selector[0] = "conv" @@ -213,7 +220,9 @@ class DifyConverter(BaseConverter): "end with": ComparisonOperator.END_WITH, "not contains": ComparisonOperator.NOT_CONTAINS, "exists": ComparisonOperator.NOT_EMPTY, - "not exists": ComparisonOperator.EMPTY + "not exists": ComparisonOperator.EMPTY, + "in": ComparisonOperator.IN, + "not in": ComparisonOperator.NOT_IN, } return operator_map.get(operator, operator) @@ -279,19 +288,25 @@ class DifyConverter(BaseConverter): ) continue - if var_type in ["file", "array[file]"]: - self.errors.append( - ExceptionDefinition( - type=ExceptionType.VARIABLE, - node_id=node["id"], - node_name=node_data["title"], - name=var["variable"], - detail=f"Unsupported Variable type for start node: {var_type}" - ) - ) + if var_type in [VariableType.FILE, VariableType.ARRAY_FILE]: + # 开始节点不支持文件变量,转为会话变量 + self._file_vars_to_conv.append(SchemaVariableDefinition( + name=var["variable"], + type=var_type.value, + required=var.get("required", False), + default=None, + description=var.get("label", ""), + )) + self.warnings.append(ExceptionDefinition( + type=ExceptionType.VARIABLE, + node_id=node["id"], + node_name=node_data["title"], + name=var["variable"], + detail=f"File variable '{var['variable']}' is not supported in start node, moved to conversation variables" + )) continue - var_def = VariableDefinition( + var_def = 
NodeVariableDefinition( name=var["variable"], type=var_type, required=var["required"], @@ -476,11 +491,11 @@ class DifyConverter(BaseConverter): node_data = node["data"] result = IterationNodeConfig.model_construct( input=self._process_list_variable_literal(node_data["iterator_selector"]), - parallel=node_data["is_parallel"], - parallel_count=node_data["parallel_nums"], + parallel=node_data.get("is_parallel", False), + parallel_count=node_data.get("parallel_nums", 4), output=self._process_list_variable_literal(node_data["output_selector"]), output_type=self.variable_type_map(node_data.get("output_type")), - flatten=node_data["flatten_output"], + flatten=node_data.get("flatten_output", False), ).model_dump() self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result) @@ -489,7 +504,23 @@ class DifyConverter(BaseConverter): def convert_assigner_node_config(self, node: dict) -> dict: node_data = node["data"] assignments = [] - for assignment in node_data["items"]: + + # Support both formats: + # 1. New format: node_data["items"] list + # 2. 
Flat format: assigned_variable_selector + input_variable_selector + write_mode + if "items" in node_data: + raw_items = node_data["items"] + elif "assigned_variable_selector" in node_data and "input_variable_selector" in node_data: + raw_items = [{ + "variable_selector": node_data["assigned_variable_selector"], + "value": node_data["input_variable_selector"], + "input_type": ValueInputType.VARIABLE, + "operation": node_data.get("write_mode", "over-write"), + }] + else: + raw_items = [] + + for assignment in raw_items: if assignment.get("operation") is None or assignment.get("value") is None: continue assignments.append( @@ -771,3 +802,119 @@ class DifyConverter(BaseConverter): show_author=node_data.get("showAuthor", True) ).model_dump() return result + + def convert_list_operator_node_config(self, node: dict) -> dict: + """Dify list-operator — convert variable path array to {{ }} selector format.""" + node_data = node["data"] + variable_path = node_data.get("variable", []) + input_list = self._process_list_variable_literal(variable_path) or "" + filter_by = node_data.get("filter_by", {"enabled": False, "conditions": []}) + # Convert each condition's comparison_operator from Dify format to native + if filter_by.get("conditions"): + converted_conditions = [] + for cond in filter_by["conditions"]: + converted_conditions.append({ + **cond, + "comparison_operator": self.convert_compare_operator( + cond.get("comparison_operator", "") + ) + }) + filter_by = {**filter_by, "conditions": converted_conditions} + result = { + "input_list": input_list, + "filter_by": filter_by, + "order_by": node_data.get("order_by", {"enabled": False, "key": "", "value": "asc"}), + "limit": node_data.get("limit", {"enabled": False, "size": -1}), + "extract_by": node_data.get("extract_by", {"enabled": False, "serial": "1"}), + } + self.config_validate(node["id"], node["data"]["title"], ListOperatorNodeConfig, result) + return result + + def convert_document_extractor_node_config(self, node: 
dict) -> dict: + """Convert Dify document-extractor node to MemoryBear DocExtractorNodeConfig. + + Dify document-extractor data fields: + variable_selector: list[str] - file variable path + """ + node_data = node["data"] + file_selector = self._process_list_variable_literal( + node_data.get("variable_selector", []) + ) or "" + result = DocExtractorNodeConfig.model_construct( + file_selector=file_selector, + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], DocExtractorNodeConfig, result) + return result + + @staticmethod + def convert_features(features: dict) -> dict: + """Convert Dify features to MemoryBear FeaturesConfigForm format.""" + if not features: + return {} + + result: dict = {} + + # opening_statement + opening = features.get("opening_statement", "") + suggested = features.get("suggested_questions", []) + result["opening_statement"] = { + "enabled": bool(opening), + "statement": opening or None, + "suggested_questions": suggested, + } + + # citation (对应 Dify retriever_resource) + retriever = features.get("retriever_resource", {}) + result["citation"] = { + "enabled": retriever.get("enabled", False) if isinstance(retriever, dict) else False, + } + + # file_upload: Dify allowed_file_types 数组 -> 前端扁平字段 + file_upload = features.get("file_upload", {}) + allowed_types = file_upload.get("allowed_file_types", []) if file_upload else [] + allowed_methods = file_upload.get("allowed_file_upload_methods", ["local_file", "remote_url"]) + if isinstance(allowed_methods, list): + if len(allowed_methods) >= 2: + transfer_method = "both" + elif allowed_methods: + transfer_method = allowed_methods[0] + else: + transfer_method = "both" + else: + transfer_method = allowed_methods or "both" + + file_config = file_upload.get("fileUploadConfig", {}) + result["file_upload"] = { + "enabled": file_upload.get("enabled", False) if file_upload else False, + "image_enabled": "image" in allowed_types, + "image_max_size_mb": 
file_config.get("image_file_size_limit", 10) if file_config else 10, + "image_allowed_extensions": ["png", "jpg", "jpeg"], + "audio_enabled": "audio" in allowed_types, + "audio_max_size_mb": file_config.get("audio_file_size_limit", 50) if file_config else 50, + "audio_allowed_extensions": ["mp3", "wav", "m4a"], + "document_enabled": "document" in allowed_types, + "document_max_size_mb": file_config.get("file_size_limit", 100) if file_config else 100, + "document_allowed_extensions": ["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"], + "video_enabled": "video" in allowed_types, + "video_max_size_mb": file_config.get("video_file_size_limit", 100) if file_config else 100, + "video_allowed_extensions": ["mp4", "mov"], + "max_file_count": file_upload.get("number_limits", 1) if file_upload else 1, + "allowed_transfer_methods": transfer_method, + } + + # text_to_speech + tts = features.get("text_to_speech", {}) + result["text_to_speech"] = { + "enabled": tts.get("enabled", False) if isinstance(tts, dict) else False, + "voice": tts.get("voice") if isinstance(tts, dict) else None, + "language": tts.get("language") if isinstance(tts, dict) else None, + "autoplay": False, + } + + # suggested_questions_after_answer + sqa = features.get("suggested_questions_after_answer", {}) + result["suggested_questions_after_answer"] = { + "enabled": sqa.get("enabled", False) if isinstance(sqa, dict) else False, + } + + return result diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index abd95408..c699f877 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -45,6 +45,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): "question-classifier": NodeType.QUESTION_CLASSIFIER, "variable-aggregator": NodeType.VAR_AGGREGATOR, "tool": NodeType.TOOL, + "list-operator": NodeType.LIST_OPERATOR, + "document-extractor": 
NodeType.DOCUMENT_EXTRACTOR, "": NodeType.NOTES } @@ -117,9 +119,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): if variable: self.conv_variables.append(con_var) - # for variables in config.get("workflow").get("environment_variables"): - # variable = self._convert_variable(variables) - # conv_variables.append(variable) + # 开始节点的文件变量合并到会话变量 + self.conv_variables.extend(self._file_vars_to_conv) + + features = self.convert_features( + self.config.get("workflow", {}).get("features", {}) + ) trigger = self._convert_trigger({}) execution_config = self._convert_execution({}) @@ -133,6 +138,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): edges=self.edges, nodes=self.nodes, variables=self.conv_variables, + features=features, warnings=self.warnings, errors=self.errors ) diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py index e96e0bf2..0f44ad72 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py @@ -22,6 +22,8 @@ from app.core.workflow.nodes.configs import ( MemoryReadNodeConfig, MemoryWriteNodeConfig, NoteNodeConfig, + ListOperatorNodeConfig, + DocExtractorNodeConfig, ) from app.core.workflow.nodes.enums import NodeType @@ -51,6 +53,8 @@ class MemoryBearConverter(BaseConverter): NodeType.MEMORY_READ: MemoryReadNodeConfig, NodeType.MEMORY_WRITE: MemoryWriteNodeConfig, NodeType.NOTES: NoteNodeConfig, + NodeType.LIST_OPERATOR: ListOperatorNodeConfig, + NodeType.DOCUMENT_EXTRACTOR: DocExtractorNodeConfig, } @staticmethod diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index daef6e82..e0bdebf3 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -31,9 +31,9 @@ logger = logging.getLogger(__name__) # Example: # "Hello {{user.name}}!" 
-> # ["Hello ", "{{user.name}}", "!"] -_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+') +_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{]+|{') # Strict variable format: {{ node_id.field_name }} -_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}') +_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}') class GraphBuilder: diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py index be0c957a..dc16df17 100644 --- a/api/app/core/workflow/engine/result_builder.py +++ b/api/app/core/workflow/engine/result_builder.py @@ -59,6 +59,9 @@ class WorkflowResultBuilder: conversation_vars = variable_pool.get_all_conversation_vars() sys_vars = variable_pool.get_all_system_vars() + # 汇总所有 knowledge 节点的 citations + citations = self.aggregate_citations(node_outputs) + return { "status": "completed" if success else "failed", "output": final_output, @@ -71,9 +74,25 @@ class WorkflowResultBuilder: "conversation_id": execution_context.conversation_id, "elapsed_time": elapsed_time, "token_usage": token_usage, + "citations": citations, "error": result.get("error"), } + @staticmethod + def aggregate_citations(node_outputs: dict) -> list: + """从所有 knowledge 节点的输出中汇总 citations,去重""" + seen = set() + citations = [] + for node_output in node_outputs.values(): + if not isinstance(node_output, dict): + continue + for c in node_output.get("citations", []): + key = c.get("document_id") + if key and key not in seen: + seen.add(key) + citations.append(c) + return citations + @staticmethod def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None: """ diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index dcc92fdb..361f99d2 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -14,7 +14,7 @@ from 
app.core.workflow.engine.variable_pool import VariablePool logger = get_logger(__name__) SCOPE_PATTERN = re.compile( - r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}" + r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)?\s*}}" ) diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 7faca82d..08d10e22 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -34,19 +34,22 @@ class LazyVariableDict: return self._cache[key] var_struct = self._source.get(key) if var_struct is None: - raise KeyError(key) - value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value() + return None + raw = var_struct.instance.get_value() + # literal 模式下 dict/list 保留结构,让 Jinja2 能继续访问子字段(如 .type) + value = raw if (not self._literal or isinstance(raw, (dict, list))) else var_struct.instance.to_literal() self._cache[key] = value return value def get(self, key, default=None): - try: - return self._resolve(key) - except KeyError: - return default + value = self._resolve(key) + return default if value is None else value def __getitem__(self, key): - return self._resolve(key) + value = self._resolve(key) + if value is None: + raise KeyError(key) + return value def __getattr__(self, key): if key.startswith('_'): @@ -164,7 +167,7 @@ class VariablePool: def transform_selector(selector): variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip() selector = VariableSelector.from_string(variable_literal).path - if len(selector) != 2: + if len(selector) not in (2, 3): raise ValueError(f"Selector not valid - {selector}") return selector @@ -196,6 +199,16 @@ class VariablePool: return None return var_instance + @staticmethod + def _extract_field(struct: "VariableStruct", field: str | None) -> Any: + """If field is given, drill into a dict/object variable's value.""" + if field is None: + return struct.instance.get_value() + value = 
struct.instance.get_value() + if not isinstance(value, dict): + raise KeyError(f"Variable is not an object, cannot access field '{field}'") + return value.get(field) + def get_instance( self, selector: str, @@ -250,12 +263,14 @@ class VariablePool: Raises: KeyError: If strict is True and the variable does not exist. """ + path = self.transform_selector(selector) variable_struct = self._get_variable_struct(selector) if variable_struct is None: if strict: raise KeyError(f"{selector} not exist") return default - + if len(path) == 3: + return self._extract_field(variable_struct, path[2]) return variable_struct.instance.get_value() def get_literal( @@ -282,12 +297,15 @@ class VariablePool: Raises: KeyError: If strict is True and the variable does not exist. """ + path = self.transform_selector(selector) variable_struct = self._get_variable_struct(selector) if variable_struct is None: if strict: raise KeyError(f"{selector} not exist") return default - + if len(path) == 3: + value = self._extract_field(variable_struct, path[2]) + return str(value) if value is not None else "" return variable_struct.instance.to_literal() async def set( @@ -318,7 +336,7 @@ class VariablePool: namespace: str, key: str, value: Any, - var_type: VariableType, + var_type: VariableType | None, mut: bool ): if self.has(f"{namespace}.{key}"): @@ -345,7 +363,14 @@ class VariablePool: Returns: 变量是否存在 """ - return self._get_variable_struct(selector) is not None + path = self.transform_selector(selector) + struct = self._get_variable_struct(selector) + if struct is None: + return False + if len(path) == 3: + value = struct.instance.get_value() + return isinstance(value, dict) and path[2] in value + return True def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict: return LazyVariableDict(self.variables.get(namespace, {}), literal) @@ -493,6 +518,23 @@ class VariablePoolInitializer: var_value = var_default else: var_value = DEFAULT_VALUE(var_type) + # Convert 
FileInput-format dicts to full FileObject dicts + if var_type == VariableType.FILE: + if not var_value: + continue + var_value = await self._resolve_file_default(var_value) + if not var_value: + continue + elif var_type == VariableType.ARRAY_FILE: + if not var_value: + var_value = [] + else: + resolved = [] + for item in var_value: + f = await self._resolve_file_default(item) + if f: + resolved.append(f) + var_value = resolved await variable_pool.new( namespace="conv", key=var_name, @@ -501,6 +543,17 @@ class VariablePoolInitializer: mut=True ) + @staticmethod + async def _resolve_file_default(file_def: dict) -> dict | None: + """Accept only already-resolved FileObject dicts (is_file=True). + FileInput-format dicts are converted at save time by WorkflowService._resolve_variables_file_defaults. + """ + if not isinstance(file_def, dict): + return None + if file_def.get("is_file"): + return file_def + return None + @staticmethod async def _init_system_vars( variable_pool: VariablePool, diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index bedf6165..5458a80c 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -395,7 +395,8 @@ class BaseNode(ABC): "output": output, "elapsed_time": elapsed_time, "token_usage": token_usage, - "error": None + "error": None, + **self._extract_extra_fields(business_result), } final_output = { "node_outputs": {self.node_id: node_output}, @@ -498,6 +499,13 @@ class BaseNode(ABC): # Default implementation returns the business result directly return business_result + def _extract_extra_fields(self, business_result: Any) -> dict: + """Extracts extra fields to merge into node_output (e.g. citations). + + Subclasses may override to inject additional metadata. + """ + return {} + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: """Extracts token usage information from the business result. 
diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index d89b208b..69c660fe 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -13,7 +13,7 @@ from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes import BaseNode from app.core.workflow.nodes.code.config import CodeNodeConfig -from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE logger = logging.getLogger(__name__) @@ -70,7 +70,8 @@ class CodeNode(BaseNode): for output in self.typed_config.output_variables: value = exec_result.get(output.name) if value is None: - raise RuntimeError(f"Return value {output.name} does not exist") + result[output.name] = DEFAULT_VALUE(output.type) + continue match output.type: case VariableType.STRING: if not isinstance(value, str): diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 31dadc38..5ec029cc 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -24,6 +24,8 @@ from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.notes.config import NoteNodeConfig +from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig +from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig __all__ = [ # 基础类 @@ -49,5 +51,7 @@ __all__ = [ "MemoryReadNodeConfig", "MemoryWriteNodeConfig", "CodeNodeConfig", - "NoteNodeConfig" + "NoteNodeConfig", + "ListOperatorNodeConfig", + "DocExtractorNodeConfig", ] diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py 
b/api/app/core/workflow/nodes/cycle_graph/iteration.py index cf7ac976..1633b9c7 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -28,86 +28,135 @@ class IterationRuntime: def __init__( self, - start_id: str, stream: bool, - graph: CompiledStateGraph, node_id: str, config: dict[str, Any], state: WorkflowState, variable_pool: VariablePool, - child_variable_pool: VariablePool, + cycle_nodes: list, + cycle_edges: list, ): """ Initialize the iteration runtime. Args: - graph: Compiled workflow graph capable of async invocation. - node_id: Unique identifier of the loop node. - config: Dictionary containing iteration node configuration. - state: Current workflow state at the point of iteration. + stream: Whether to run in streaming mode. When True, each iteration + uses graph.astream and emits cycle_item events in real time. + When False, graph.ainvoke is used instead. + node_id: The unique identifier of the iteration node in the workflow. + Also used as the variable namespace for item/index inside + the subgraph (e.g. {{ node_id.item }}). + config: Raw configuration dict for the iteration node, parsed into + IterationNodeConfig. Controls input/output variable selectors, + parallel execution settings, and output flattening. + state: The parent workflow state at the point the iteration node is + entered. Each task receives a copy of this state as its + starting point. + variable_pool: The parent VariablePool containing all variables available + at the time the iteration node executes, including sys.*, + conv.*, and outputs from upstream nodes. Used as the source + for deep-copying into each task's independent child pool. + cycle_nodes: List of node config dicts belonging to this iteration's + subgraph (i.e. nodes whose cycle field equals node_id). + Passed to GraphBuilder when constructing each task's subgraph. + cycle_edges: List of edge config dicts connecting nodes within the subgraph. 
+ Passed to GraphBuilder alongside cycle_nodes. """ - self.start_id = start_id self.stream = stream - self.graph = graph self.state = state self.node_id = node_id self.typed_config = IterationNodeConfig(**config) self.looping = True self.variable_pool = variable_pool - self.child_variable_pool = child_variable_pool + self.cycle_nodes = cycle_nodes + self.cycle_edges = cycle_edges self.event_write = get_stream_writer() - self.checkpoint = RunnableConfig( - configurable={ - "thread_id": uuid.uuid4() - } - ) self.output_value = None self.result: list = [] - async def _init_iteration_state(self, item, idx): + def _build_child_graph(self) -> tuple[CompiledStateGraph, VariablePool, str]: """ - Initialize a per-iteration copy of the workflow state. + Build an independent compiled subgraph for a single iteration task. - Args: - item: Current element from the input array for this iteration. - idx: Index of the element in the input array. + Each call creates a brand-new VariablePool by deep-copying the parent pool, + then passes it to GraphBuilder. GraphBuilder binds this pool to every node's + execution closure at build time, so the pool and the subgraph always reference + the same object. This is the key design invariant: item/index written into the + pool after build will be visible to all nodes inside the subgraph. Returns: - A copy of the workflow state with iteration-specific variables set. + graph: The compiled LangGraph subgraph ready for invocation. + child_pool: The VariablePool bound to this subgraph's node closures. + Callers must write item/index into this pool before invoking + the graph, and read output from it after invocation. + start_node_id: The ID of the CYCLE_START node inside the subgraph, + used to set the initial activation signal in workflow state. 
""" - loopstate = WorkflowState( - **self.state + from app.core.workflow.engine.graph_builder import GraphBuilder + child_pool = VariablePool() + child_pool.copy(self.variable_pool) + builder = GraphBuilder( + {"nodes": self.cycle_nodes, "edges": self.cycle_edges}, + stream=self.stream, + variable_pool=child_pool, + cycle=self.node_id, ) - self.child_variable_pool.copy(self.variable_pool) - await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True) - await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True) - loopstate["node_outputs"][self.node_id] = { - "item": item, - "index": idx, - } + graph = builder.build() + return graph, builder.variable_pool, builder.start_node_id + + async def _init_iteration_state(self, item, idx, child_pool: VariablePool, start_id: str): + """ + Initialize the workflow state for a single iteration. + + Writes the current item and its index into child_pool under the iteration + node's namespace (e.g. iteration_xxx.item, iteration_xxx.index), making them + accessible to downstream nodes inside the subgraph via variable selectors. + + Also prepares a copy of the parent workflow state with: + - node_outputs[node_id] set to {item, index} so the state snapshot is consistent + with the pool values. + - looping flag set to 1 (active) to signal the subgraph is inside a cycle. + - activate[start_id] set to True to trigger the CYCLE_START node. + + Args: + item: The current element from the input array. + idx: The zero-based index of this element in the input array. + child_pool: The VariablePool bound to this iteration's subgraph. + Must be the same object returned by _build_child_graph. + start_id: The ID of the CYCLE_START node inside the subgraph. + + Returns: + A WorkflowState instance ready to be passed to graph.ainvoke or graph.astream. 
+ """ + loopstate = WorkflowState(**self.state) + await child_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True) + await child_pool.new(self.node_id, "index", idx, VariableType.type_map(idx), mut=True) + loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx} loopstate["looping"] = 1 - loopstate["activate"][self.start_id] = True + loopstate["activate"][start_id] = True return loopstate - def merge_conv_vars(self): - self.variable_pool.variables["conv"].update( - self.child_variable_pool.variables["conv"] - ) + def _merge_conv_vars(self, child_pool: VariablePool): + self.variable_pool.variables["conv"].update(child_pool.variables["conv"]) async def run_task(self, item, idx): """ Execute a single iteration asynchronously. + Each task builds its own subgraph so the variable pool closure is independent. - Args: - item: The input element for this iteration. - idx: The index of this iteration. + Returns: + Tuple of (idx, output, result, child_pool, stopped) """ + graph, child_pool, start_id = self._build_child_graph() + checkpoint = RunnableConfig(configurable={"thread_id": uuid.uuid4()}) + init_state = await self._init_iteration_state(item, idx, child_pool, start_id) + if self.stream: - async for event in self.graph.astream( - await self._init_iteration_state(item, idx), + async for event in graph.astream( + init_state, stream_mode=["debug"], - config=self.checkpoint + config=checkpoint ): if isinstance(event, tuple) and len(event) == 2: mode, data = event @@ -117,7 +166,6 @@ class IterationRuntime: event_type = data.get("type") payload = data.get("payload", {}) node_name = payload.get("name") - if node_name and node_name.startswith("nop"): continue if event_type == "task_result": @@ -140,17 +188,13 @@ class IterationRuntime: "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage") } }) - result = self.graph.get_state(config=self.checkpoint).values + result = graph.get_state(config=checkpoint).values 
else: - result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) - output = self.child_variable_pool.get_value(self.output_value) - if isinstance(output, list) and self.typed_config.flatten: - self.result.extend(output) - else: - self.result.append(output) - if result["looping"] == 2: - self.looping = False - return result + result = await graph.ainvoke(init_state) + + output = child_pool.get_value(self.output_value) + stopped = result["looping"] == 2 + return idx, output, result, child_pool, stopped def _create_iteration_tasks(self, array_obj, idx): """ @@ -196,16 +240,32 @@ class IterationRuntime: tasks = self._create_iteration_tasks(array_obj, idx) logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}") idx += self.typed_config.parallel_count - child_state.extend(await asyncio.gather(*tasks)) - self.merge_conv_vars() + batch = await asyncio.gather(*tasks) + # Sort by idx to preserve order, then collect results + batch_sorted = sorted(batch, key=lambda x: x[0]) + for _, output, result, child_pool, stopped in batch_sorted: + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + child_state.append(result) + self._merge_conv_vars(child_pool) + if stopped: + self.looping = False else: # Execute iterations sequentially while idx < len(array_obj) and self.looping: logger.info(f"Iteration node {self.node_id}: running") item = array_obj[idx] - result = await self.run_task(item, idx) - self.merge_conv_vars() + _, output, result, child_pool, stopped = await self.run_task(item, idx) + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + self._merge_conv_vars(child_pool) child_state.append(result) + if stopped: + self.looping = False idx += 1 logger.info(f"Iteration node {self.node_id}: execution completed") return { diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py 
b/api/app/core/workflow/nodes/cycle_graph/node.py index fc80939f..002c34df 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -55,9 +55,9 @@ class CycleGraphNode(BaseNode): if config.output_type in [ VariableType.ARRAY_FILE, VariableType.ARRAY_STRING, - VariableType.NUMBER, + VariableType.ARRAY_NUMBER, VariableType.ARRAY_OBJECT, - VariableType.BOOLEAN + VariableType.ARRAY_BOOLEAN ]: if config.flatten: outputs['output'] = config.output_type @@ -123,7 +123,7 @@ class CycleGraphNode(BaseNode): return cycle_nodes, cycle_edges - def build_graph(self): + def build_graph(self, variable_pool: VariablePool): """ Build and compile the internal subgraph for this cycle node. @@ -135,6 +135,7 @@ class CycleGraphNode(BaseNode): from app.core.workflow.engine.graph_builder import GraphBuilder self.child_variable_pool = VariablePool() + self.child_variable_pool.copy(variable_pool) builder = GraphBuilder( { "nodes": self.cycle_nodes, @@ -165,8 +166,8 @@ class CycleGraphNode(BaseNode): Raises: RuntimeError: If the node type is unsupported. 
""" - self.build_graph() if self.node_type == NodeType.LOOP: + self.build_graph(variable_pool) return await LoopRuntime( start_id=self.start_node_id, stream=False, @@ -179,20 +180,19 @@ class CycleGraphNode(BaseNode): ).run() if self.node_type == NodeType.ITERATION: return await IterationRuntime( - start_id=self.start_node_id, stream=False, - graph=self.graph, node_id=self.node_id, config=self.config, state=state, variable_pool=variable_pool, - child_variable_pool=self.child_variable_pool + cycle_nodes=self.cycle_nodes, + cycle_edges=self.cycle_edges, ).run() raise RuntimeError("Unknown cycle node type") async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): - self.build_graph() if self.node_type == NodeType.LOOP: + self.build_graph(variable_pool) yield { "__final__": True, "result": await LoopRuntime( @@ -211,14 +211,13 @@ class CycleGraphNode(BaseNode): yield { "__final__": True, "result": await IterationRuntime( - start_id=self.start_node_id, stream=True, - graph=self.graph, node_id=self.node_id, config=self.config, state=state, variable_pool=variable_pool, - child_variable_pool=self.child_variable_pool + cycle_nodes=self.cycle_nodes, + cycle_edges=self.cycle_edges, ).run() } return diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py index bd828760..cada495c 100644 --- a/api/app/core/workflow/nodes/document_extractor/node.py +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -14,12 +14,22 @@ logger = logging.getLogger(__name__) def _file_object_to_file_input(f: FileObject) -> FileInput: """Convert workflow FileObject to multimodal FileInput.""" + file_type = f.origin_file_type or "" + # Prefer mime_type for more accurate type detection + if not file_type and f.mime_type: + file_type = f.mime_type + resolved_type = FileType.trans(f.type) if isinstance(f.type, str) else f.type + if resolved_type != FileType.DOCUMENT: + raise ValueError( + f"Document 
extractor only supports document files, got type '{f.type}' " + f"(name={f.name or f.file_id or f.url})" + ) return FileInput( - type=FileType.DOCUMENT, + type=resolved_type, transfer_method=TransferMethod(f.transfer_method), url=f.url or None, upload_file_id=f.file_id or None, - file_type=f.origin_file_type or "", + file_type=file_type, ) @@ -81,6 +91,7 @@ class DocExtractorNode(BaseNode): from app.services.multimodal_service import MultimodalService svc = MultimodalService(db) for f in files: + label = f.name or f.url or f.file_id try: file_input = _file_object_to_file_input(f) # Ensure URL is populated for local files @@ -93,7 +104,7 @@ class DocExtractorNode(BaseNode): chunks.append(text) except Exception as e: logger.error( - f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}", + f"Node {self.node_id}: failed to extract file '{label}': {e}", exc_info=True, ) chunks.append("") diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 529cd0b3..bd0d8426 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -24,6 +24,7 @@ class NodeType(StrEnum): MEMORY_READ = "memory-read" MEMORY_WRITE = "memory-write" DOCUMENT_EXTRACTOR = "document-extractor" + LIST_OPERATOR = "list-operator" UNKNOWN = "unknown" NOTES = "notes" @@ -45,6 +46,8 @@ class ComparisonOperator(StrEnum): LE = "le" GT = "gt" GE = "ge" + IN = "in" + NOT_IN = "not_in" class LogicOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index e1b84f0c..72474436 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -72,8 +72,9 @@ class HttpContentTypeConfig(BaseModel): @classmethod def validate_data(cls, v, info): content_type = info.data.get("content_type") - if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData): - raise 
ValueError("When content_type is 'form-data', data must be of type HttpFormData") + if content_type == HttpContentType.FROM_DATA and ( + not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v)): + raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData") elif content_type in [HttpContentType.JSON] and not isinstance(v, str): raise ValueError("When content_type is JSON, data must be of type str") elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict): diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 086bee4a..783c230b 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -260,17 +260,22 @@ class HttpRequestNode(BaseNode): )) case HttpContentType.FROM_DATA: data = {} - content["files"] = {} + files = [] for item in self.typed_config.body.data: + key = self._render_template(item.key, variable_pool) if item.type == "text": - data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, - variable_pool) + data[key] = self._render_template(item.value, variable_pool) elif item.type == "file": - content["files"][self._render_template(item.key, variable_pool)] = ( - uuid.uuid4().hex, - await variable_pool.get_instance(item.value).get_content() - ) + file_instance = variable_pool.get_instance(item.value) + if isinstance(file_instance, ArrayVariable): + for v in file_instance.value: + if isinstance(v, FileVariable): + files.append((key, (uuid.uuid4().hex, await v.get_content()))) + elif isinstance(file_instance, FileVariable): + files.append((key, (uuid.uuid4().hex, await file_instance.get_content()))) content["data"] = data + if files: + content["files"] = files case HttpContentType.BINARY: content["files"] = [] file_instence = variable_pool.get_instance(self.typed_config.body.data) diff --git 
a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index d0b6d098..2a8c5249 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -8,6 +8,8 @@ from langchain_core.documents import Document from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig +from app.core.rag.llm.chat_model import Base +from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.models.chunk import DocumentChunk from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.engine.state_manager import WorkflowState @@ -34,6 +36,21 @@ class KnowledgeRetrievalNode(BaseNode): "output": VariableType.ARRAY_STRING } + def _extract_output(self, business_result: Any) -> Any: + """下游节点只拿 chunks 列表""" + if isinstance(business_result, dict) and "chunks" in business_result: + return business_result["chunks"] + return business_result + + @staticmethod + def _extract_citations(business_result: Any) -> list: + if isinstance(business_result, dict): + return business_result.get("citations", []) + return [] + + def _extract_extra_fields(self, business_result: Any) -> dict: + return {"citations": self._extract_citations(business_result)} + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: return { "query": self._render_template(self.typed_config.query, variable_pool), @@ -216,23 +233,23 @@ class KnowledgeRetrievalNode(BaseNode): } ) ) - case RetrieveType.HYBRID: + case retrieve_type if retrieve_type in (RetrieveType.HYBRID, RetrieveType.Graph): rs1_task = asyncio.to_thread( - vector_service.search_by_vector, **{ - "query": query, - "top_k": kb_config.top_k, - "indices": indices, - "score_threshold": kb_config.vector_similarity_weight - } - ) + vector_service.search_by_vector, **{ + "query": query, + 
"top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) rs2_task = asyncio.to_thread( - vector_service.search_by_full_text, **{ - "query": query, - "top_k": kb_config.top_k, - "indices": indices, - "score_threshold": kb_config.similarity_threshold - } - ) + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) rs1, rs2 = await asyncio.gather(rs1_task, rs2_task) # Deduplicate hybrid retrieval results @@ -252,6 +269,33 @@ class KnowledgeRetrievalNode(BaseNode): key=lambda d: d.metadata.get("score", 0), reverse=True )[:kb_config.top_k]) + if kb_config.retrieve_type == RetrieveType.Graph: + from app.core.rag.common.settings import kg_retriever + llm_key = self.model_balance(db_knowledge.llm) + emb_key = self.model_balance(db_knowledge.embedding) + chat_model = Base( + key=llm_key.api_key, + model_name=llm_key.model_name, + base_url=llm_key.api_base + ) + embedding_model = OpenAIEmbed( + key=emb_key.api_key, + model_name=emb_key.model_name, + base_url=emb_key.api_base + ) + doc = await asyncio.to_thread( + kg_retriever.retrieval, + question=query, + workspace_ids=[str(db_knowledge.workspace_id)], + kb_ids=[str(kb_config.kb_id)], + emb_mdl=embedding_model, + llm=chat_model + ) + if doc: + rs.insert(0, DocumentChunk( + page_content=doc.get("page_content", ""), + metadata=doc.get("metadata", {}) + )) case _: raise RuntimeError("Unknown retrieval type") return rs @@ -314,4 +358,20 @@ class KnowledgeRetrievalNode(BaseNode): logger.info( f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" ) - return [chunk.page_content for chunk in final_rs] + citations = [] + seen_doc_ids = set() + for chunk in final_rs: + meta = chunk.metadata or {} + doc_id = meta.get("document_id") or meta.get("doc_id") + if doc_id and doc_id not in seen_doc_ids: + seen_doc_ids.add(doc_id) + 
citations.append({ + "document_id": str(doc_id), + "file_name": meta.get("file_name", ""), + "knowledge_id": str(meta.get("knowledge_id", kb_config.kb_id)), + "score": meta.get("score", 0.0), + }) + return { + "chunks": [chunk.page_content for chunk in final_rs], + "citations": citations, + } diff --git a/api/app/core/workflow/nodes/list_operator/__init__.py b/api/app/core/workflow/nodes/list_operator/__init__.py new file mode 100644 index 00000000..1877586e --- /dev/null +++ b/api/app/core/workflow/nodes/list_operator/__init__.py @@ -0,0 +1,3 @@ +from .node import ListOperatorNode + +__all__ = ["ListOperatorNode"] diff --git a/api/app/core/workflow/nodes/list_operator/config.py b/api/app/core/workflow/nodes/list_operator/config.py new file mode 100644 index 00000000..6fde6a57 --- /dev/null +++ b/api/app/core/workflow/nodes/list_operator/config.py @@ -0,0 +1,49 @@ +from typing import Any +from pydantic import BaseModel, Field, field_validator + +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.nodes.enums import ComparisonOperator + + +class FilterCondition(BaseModel): + key: str = "" + comparison_operator: ComparisonOperator = ComparisonOperator.CONTAINS + value: str | list[str] | bool = "" + + +class FilterBy(BaseModel): + enabled: bool = False + conditions: list[FilterCondition] = Field(default_factory=list) + + +class OrderByConfig(BaseModel): + enabled: bool = False + key: str = "" + value: str = "asc" # "asc" | "desc" + + +class Limit(BaseModel): + enabled: bool = False + size: int = -1 + + +class ExtractConfig(BaseModel): + enabled: bool = False + serial: str = "1" # 1-based index string, e.g. "1" = first + + @field_validator("serial", mode="before") + @classmethod + def coerce_serial(cls, v): + return str(v) + + +class ListOperatorNodeConfig(BaseNodeConfig): + """ + List Operator node config. + Operation order: filter -> extract -> order -> limit + """ + input_list: str = Field(..., description="Variable selector, e.g. 
{{ sys.files }} or {{ conv.uploaded_files }}") + filter_by: FilterBy = Field(default_factory=FilterBy) + order_by: OrderByConfig = Field(default_factory=OrderByConfig) + limit: Limit = Field(default_factory=Limit) + extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/app/core/workflow/nodes/list_operator/node.py b/api/app/core/workflow/nodes/list_operator/node.py new file mode 100644 index 00000000..edc15ed1 --- /dev/null +++ b/api/app/core/workflow/nodes/list_operator/node.py @@ -0,0 +1,150 @@ +import logging +from typing import Any + +from app.core.workflow.engine.state_manager import WorkflowState +from app.core.workflow.engine.variable_pool import VariablePool +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.enums import ComparisonOperator +from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig, FilterCondition +from app.core.workflow.variable.base_variable import VariableType + +logger = logging.getLogger(__name__) + +# File object fields that hold string values +_FILE_STRING_KEYS = {"type", "name", "url", "extension", "mime_type", "transfer_method", "origin_file_type", "file_id"} +_FILE_NUMBER_KEYS = {"size"} + + +class ListOperatorNode(BaseNode): + def __init__(self, node_config: dict, workflow_config: dict, down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) + self.typed_config: ListOperatorNodeConfig | None = None + + def _output_types(self) -> dict[str, VariableType]: + return { + "result": VariableType.ANY, + "first_record": VariableType.ANY, + "last_record": VariableType.ANY, + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: + self.typed_config = ListOperatorNodeConfig(**self.config) + cfg = self.typed_config + + # Resolve input variable from path selector + items: list = self.get_variable(cfg.input_list, variable_pool) + if not isinstance(items, list): + raise 
TypeError(f"Variable '{cfg.input_list}' must be an array, got {type(items)}") + + result = list(items) + + # 1. Filter + if cfg.filter_by.enabled and cfg.filter_by.conditions: + for condition in cfg.filter_by.conditions: + result = [item for item in result if self._match_condition(item, condition, variable_pool)] + + # 2. Extract (take single item by 1-based serial index) + if cfg.extract_by.enabled: + serial_str = self._resolve_value(cfg.extract_by.serial, variable_pool) + idx = int(serial_str) - 1 + if idx < 0 or idx >= len(result): + raise ValueError(f"extract_by.serial={cfg.extract_by.serial} out of range (list length={len(result)})") + result = [result[idx]] + + # 3. Order + if cfg.order_by.enabled: + reverse = cfg.order_by.value == "desc" + key_fn = self._make_sort_key(cfg.order_by.key) + result = sorted(result, key=key_fn, reverse=reverse) + + # 4. Limit (take first N) + if cfg.limit.enabled and cfg.limit.size > 0: + result = result[:cfg.limit.size] + + return { + "result": result, + "first_record": result[0] if result else None, + "last_record": result[-1] if result else None, + } + + @staticmethod + def _resolve_value(value: str, variable_pool: VariablePool) -> Any: + """If value is a {{ namespace.key }} variable selector, resolve it from the pool. 
+ Otherwise return the raw string.""" + import re + m = re.fullmatch(r"\{\{\s*(\w+\.\w+)\s*}}", value.strip()) + if m: + resolved = variable_pool.get_value(value, default=value, strict=False) + return resolved + return value + + @staticmethod + def _make_sort_key(key: str): + def key_fn(item): + if isinstance(item, dict): + return item.get(key) or "" + return item + return key_fn + + def _match_condition(self, item: Any, cond: FilterCondition, variable_pool: VariablePool) -> bool: + op = cond.comparison_operator + value = cond.value + + # Resolve value if it's a variable reference {{ namespace.key }} + if isinstance(value, str): + value = self._resolve_value(value, variable_pool) + + # Resolve left value + if isinstance(item, dict): + left = item.get(cond.key) + else: + left = item # primitive array: compare element directly + + # Determine if this field should be compared as a string + is_string_field = isinstance(item, dict) and cond.key in _FILE_STRING_KEYS + + # Numeric operators + if op == ComparisonOperator.EQ: + if is_string_field: + return str(left) == str(value) + return self._safe_num(left) == self._safe_num(value) + if op == ComparisonOperator.NE: + if is_string_field: + return str(left) != str(value) + return self._safe_num(left) != self._safe_num(value) + if op == ComparisonOperator.LT: + return self._safe_num(left) < self._safe_num(value) + if op == ComparisonOperator.LE: + return self._safe_num(left) <= self._safe_num(value) + if op == ComparisonOperator.GT: + return self._safe_num(left) > self._safe_num(value) + if op == ComparisonOperator.GE: + return self._safe_num(left) >= self._safe_num(value) + + # String / sequence operators + left_str = str(left) if left is not None else "" + if op == ComparisonOperator.CONTAINS: + return str(value) in left_str + if op == ComparisonOperator.NOT_CONTAINS: + return str(value) not in left_str + if op == ComparisonOperator.START_WITH: + return left_str.startswith(str(value)) + if op == 
ComparisonOperator.END_WITH: + return left_str.endswith(str(value)) + if op == ComparisonOperator.IN: + return left_str in (value if isinstance(value, list) else [str(value)]) + if op == ComparisonOperator.NOT_IN: + return left_str not in (value if isinstance(value, list) else [str(value)]) + if op == ComparisonOperator.EMPTY: + return not left + if op == ComparisonOperator.NOT_EMPTY: + return bool(left) + + raise ValueError(f"Unsupported operator: {op}") + + @staticmethod + def _safe_num(v) -> float: + try: + return float(v) + except (TypeError, ValueError): + return 0.0 diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 771262c1..b815c80f 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -116,6 +116,11 @@ class LLMNodeConfig(BaseNodeConfig): description="Top-p 采样参数" ) + json_output: bool = Field( + default=False, + description="是否以 JSON 格式输出" + ) + frequency_penalty: float | None = Field( default=None, ge=-2.0, diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index a691001f..664a28fa 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -22,6 +22,7 @@ from app.db import get_db_context from app.models import ModelType from app.schemas.model_schema import ModelInfo from app.services.model_service import ModelConfigService +from app.models.models_model import ModelProvider logger = logging.getLogger(__name__) @@ -126,7 +127,11 @@ class LLMNode(BaseNode): # 4. 
创建 LLM 实例(使用已提取的数据) # 注意:对于流式输出,需要在模型初始化时设置 streaming=True - extra_params = {"streaming": stream} if stream else {} + extra_params: dict[str, Any] = {"streaming": stream} if stream else {} + if self.typed_config.temperature is not None: + extra_params["temperature"] = self.typed_config.temperature + if self.typed_config.max_tokens is not None: + extra_params["max_tokens"] = self.typed_config.max_tokens llm = RedBearLLM( RedBearModelConfig( @@ -135,7 +140,9 @@ class LLMNode(BaseNode): api_key=model_info.api_key, base_url=model_info.api_base, extra_params=extra_params, - is_omni=model_info.is_omni + is_omni=model_info.is_omni, + capability=model_info.capability, + json_output=self.typed_config.json_output, ), type=model_info.model_type ) @@ -213,9 +220,20 @@ class LLMNode(BaseNode): messages = messages[:-1] + history_message + messages[-1:] self.messages = messages else: - # 使用简单的 prompt 格式(向后兼容) + # 使用简单的 prompt 格式(向后兼容)——包装为标准消息列表以兼容所有 provider prompt_template = self.config.get("prompt", "") - self.messages = self._render_template(prompt_template, variable_pool) + rendered = self._render_template(prompt_template, variable_pool) + self.messages = [{"role": "user", "content": rendered}] + + # ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format,在 system prompt 中注入 + if (self.typed_config.json_output + and model_info.provider.lower() == ModelProvider.DASHSCOPE + and not model_info.is_omni): + system_msg = next((m for m in self.messages if m["role"] == "system"), None) + if system_msg: + system_msg["content"] += "\n请以JSON格式输出。" + else: + self.messages.insert(0, {"role": "system", "content": "请以JSON格式输出。"}) return llm @@ -245,7 +263,10 @@ class LLMNode(BaseNode): logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") # 返回 AIMessage(包含响应元数据) - return AIMessage(content=content, response_metadata=response.response_metadata) + return AIMessage(content=content, response_metadata={ + **response.response_metadata, + "token_usage": getattr(response, 'usage_metadata', 
None) or response.response_metadata.get('token_usage') + }) def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录)""" @@ -304,15 +325,16 @@ class LLMNode(BaseNode): # 调用 LLM(流式,支持字符串或消息列表) last_meta_data = {} + last_usage_metadata = {} async for chunk in llm.astream(self.messages): - # 提取内容 if hasattr(chunk, 'content'): content = self.process_model_output(chunk.content) else: content = str(chunk) - if hasattr(chunk, 'response_metadata'): - if chunk.response_metadata: - last_meta_data = chunk.response_metadata + if hasattr(chunk, 'response_metadata') and chunk.response_metadata: + last_meta_data = chunk.response_metadata + if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata: + last_usage_metadata = chunk.usage_metadata # 只有当内容不为空时才处理 if content: @@ -335,7 +357,10 @@ class LLMNode(BaseNode): # 构建完整的 AIMessage(包含元数据) final_message = AIMessage( content=full_response, - response_metadata=last_meta_data + response_metadata={ + **last_meta_data, + "token_usage": last_usage_metadata or last_meta_data.get('token_usage') + } ) # yield 完成标记 diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 49add867..1dfcce74 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -27,6 +27,7 @@ from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.breaker import BreakNode from app.core.workflow.nodes.tool import ToolNode from app.core.workflow.nodes.document_extractor import DocExtractorNode +from app.core.workflow.nodes.list_operator import ListOperatorNode logger = logging.getLogger(__name__) @@ -51,7 +52,8 @@ WorkflowNode = Union[ MemoryReadNode, MemoryWriteNode, CodeNode, - DocExtractorNode + DocExtractorNode, + ListOperatorNode ] @@ -83,7 +85,8 @@ class NodeFactory: NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_WRITE: MemoryWriteNode, NodeType.CODE: 
CodeNode, - NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode + NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode, + NodeType.LIST_OPERATOR: ListOperatorNode } @classmethod diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 3dc5fcc3..901eddcf 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -12,7 +12,7 @@ from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig -from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.db import get_db_read from app.models import ModelType from app.services.model_service import ModelConfigService @@ -45,6 +45,12 @@ class ParameterExtractorNode(BaseNode): "model_id": str(self.typed_config.model_id), } + def _extract_output(self, business_result: Any) -> Any: + final_output = {} + for param in self.typed_config.params: + final_output[param.name] = business_result.get(param.name) or DEFAULT_VALUE(self.output_types[param.name]) + return final_output + def _output_types(self) -> dict[str, VariableType]: outputs = {} for param in self.typed_config.params: @@ -109,6 +115,7 @@ class ParameterExtractorNode(BaseNode): api_key = api_config.api_key api_base = api_config.api_base is_omni = api_config.is_omni + capability = api_config.capability model_type = config.type llm = RedBearLLM( @@ -201,7 +208,10 @@ class ParameterExtractorNode(BaseNode): ]) model_resp = await llm.ainvoke(messages) - self.response_metadata = model_resp.response_metadata + self.response_metadata = { + **model_resp.response_metadata, + "token_usage": getattr(model_resp, 'usage_metadata', None) or 
model_resp.response_metadata.get('token_usage') + } model_message = self.process_model_output(model_resp.content) result = json_repair.repair_json(model_message, return_objects=True) logger.info(f"node: {self.node_id} get params:{result}") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 31fadaf6..74ff1cf9 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -62,6 +62,7 @@ class QuestionClassifierNode(BaseNode): api_key = api_config.api_key base_url = api_config.api_base is_omni = api_config.is_omni + capability = api_config.capability model_type = config.type return RedBearLLM( @@ -135,7 +136,10 @@ class QuestionClassifierNode(BaseNode): response = await llm.ainvoke(messages) result = self.process_model_output(response.content) - self.response_metadata = response.response_metadata + self.response_metadata = { + **response.response_metadata, + "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage') + } if result in category_names: category = result diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 72c5c6a8..410f64c3 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -15,6 +15,7 @@ from app.services.tool_service import ToolService logger = logging.getLogger(__name__) TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}") +PURE_VARIABLE_PATTERN = re.compile(r"^\{\{\s*([\w.]+)\s*}}$") class ToolNode(BaseNode): @@ -52,13 +53,21 @@ class ToolNode(BaseNode): # 渲染工具参数 rendered_parameters = {} for param_name, param_template in self.typed_config.tool_parameters.items(): - if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): - try: - rendered_value = self._render_template(param_template, variable_pool) - except Exception as e: - raise 
ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + if isinstance(param_template, str): + pure_match = PURE_VARIABLE_PATTERN.match(param_template) + if pure_match: + # 纯单变量引用直接取原始值,保留 int/bool/float 等类型 + rendered_value = self.get_variable(pure_match.group(1), variable_pool, strict=False) + if rendered_value is None: + rendered_value = self._render_template(param_template, variable_pool) + elif TEMPLATE_PATTERN.search(param_template): + try: + rendered_value = self._render_template(param_template, variable_pool) + except Exception as e: + raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + else: + rendered_value = param_template else: - # 非模板参数(数字/布尔/普通字符串)直接保留原值 rendered_value = param_template rendered_parameters[param_name] = rendered_value diff --git a/api/app/core/workflow/utils/file_processor.py b/api/app/core/workflow/utils/file_processor.py index ae406ab0..0bedf9a7 100644 --- a/api/app/core/workflow/utils/file_processor.py +++ b/api/app/core/workflow/utils/file_processor.py @@ -1,7 +1,10 @@ # -*- coding: UTF-8 -*- -# Author: Eternity -# @Email: 1533512157@qq.com -# @Time : 2026/3/10 13:36 +import mimetypes +import os +import uuid +from typing import Any +from urllib.parse import urlparse, unquote + TRANSFORM_FILE_TYPE = { 'text/plain': 'document/text', 'text/markdown': 'document/markdown', @@ -52,5 +55,143 @@ ALLOWED_FILE_TYPES = [ def mime_to_file_type(mime_type): if mime_type not in ALLOWED_FILE_TYPES: return None - return TRANSFORM_FILE_TYPE.get(mime_type, mime_type) + + +def build_file_object_dict_from_url(url: str, file_type: str, origin_file_type: str) -> dict[str, Any]: + """Build a FileObject dict for a remote_url file using only URL parsing (no HTTP request). + Used as fallback when HTTP request fails. 
+ """ + raw_path = url.split("?")[0] + name = unquote(os.path.basename(urlparse(url).path)) or None + _, ext = os.path.splitext(name or "") + extension = ext.lstrip(".").lower() if ext else None + guessed_mime = mimetypes.guess_type(url)[0] + return { + "type": file_type, + "url": url, + "transfer_method": "remote_url", + "origin_file_type": origin_file_type, + "file_id": None, + "name": name, + "size": None, + "extension": extension, + "mime_type": guessed_mime or origin_file_type, + "is_file": True, + } + + +async def fetch_remote_file_meta( + url: str, + file_type: str, + origin_file_type: str, +) -> dict[str, Any]: + """Fetch remote file metadata via HEAD (fallback GET) and build a FileObject dict. + Falls back to URL-only parsing if the HTTP request fails. + """ + import httpx + + name = extension = None + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.head(url, follow_redirects=True) + if resp.status_code != 200: + resp = await client.get(url, follow_redirects=True) + + cl = resp.headers.get("Content-Length") + size = int(cl) if cl else None + + ct = resp.headers.get("Content-Type", "").split(";")[0].strip() + mime_type = ct or origin_file_type + + cd = resp.headers.get("Content-Disposition", "") + if "filename=" in cd: + name = cd.split("filename=")[-1].strip('"').strip("'") + if not name: + name = unquote(os.path.basename(urlparse(url).path)) or None + + if name: + _, ext = os.path.splitext(name) + extension = ext.lstrip(".").lower() if ext else None + if not extension and mime_type: + ext = mimetypes.guess_extension(mime_type) + extension = ext.lstrip(".").lower() if ext else None + except Exception: + return build_file_object_dict_from_url(url, file_type, origin_file_type) + + return build_file_object_dict_from_meta( + file_type=file_type, + transfer_method="remote_url", + origin_file_type=origin_file_type, + file_id=None, + url=url, + file_name=name, + file_size=size, + file_ext=extension, + content_type=mime_type, + ) 
+ + +def build_file_object_dict_from_meta( + file_type: str, + transfer_method: str, + origin_file_type: str, + file_id: str, + url: str, + file_name: str | None, + file_size: int | None, + file_ext: str | None, + content_type: str | None, +) -> dict[str, Any]: + """Build a FileObject dict from already-fetched FileMetadata fields.""" + ext = (file_ext or "").lstrip(".") + return { + "type": file_type, + "url": url, + "transfer_method": transfer_method, + "origin_file_type": content_type or origin_file_type, + "file_id": file_id, + "name": file_name, + "size": file_size, + "extension": ext.lower() if ext else None, + "mime_type": content_type, + "is_file": True, + } + + +def resolve_local_file_object_dict( + db, + upload_file_id: str | uuid.UUID, + file_type: str, + origin_file_type: str, +) -> dict[str, Any] | None: + """Query FileMetadata and build a FileObject dict for a local_file. + Returns None if the file is not found or not completed. + """ + from app.models.file_metadata_model import FileMetadata + from app.core.config import settings + + try: + fid = uuid.UUID(str(upload_file_id)) + except ValueError: + return None + + meta = db.query(FileMetadata).filter( + FileMetadata.id == fid, + FileMetadata.status == "completed" + ).first() + if not meta: + return None + + url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{fid}" + return build_file_object_dict_from_meta( + file_type=file_type, + transfer_method="local_file", + origin_file_type=origin_file_type, + file_id=str(fid), + url=url, + file_name=meta.file_name, + file_size=meta.file_size, + file_ext=meta.file_ext, + content_type=meta.content_type, + ) diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 0ad74865..7aa107cf 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -301,7 +301,7 @@ class WorkflowValidator: for node in nodes: if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not 
node.get("name"): errors.append( - f"节点 {node.get('id')} 缺少名称(发布时必须提供)" + f"节点 {node.get('name')} 缺少名称(发布时必须提供)" ) # 2. 验证所有非 start/end 节点都有配置 @@ -311,7 +311,7 @@ class WorkflowValidator: config = node.get("config") if not config or not isinstance(config, dict): errors.append( - f"节点 {node.get('id')} 缺少配置(发布时必须提供)" + f"节点 {node.get('name')} 缺少配置(发布时必须提供)" ) # 3. 验证必填变量 diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index f5d8ff8f..4f034641 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -91,7 +91,7 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any: case VariableType.OBJECT: return {} case VariableType.FILE: - return None + return {} case VariableType.ARRAY_STRING: return [] case VariableType.ARRAY_NUMBER: @@ -113,6 +113,12 @@ class FileObject(BaseModel): origin_file_type: str file_id: str | None + # Extended file metadata + name: str | None = None + size: int | None = None + extension: str | None = None + mime_type: str | None = None + content_cache: dict = Field(default_factory=dict) is_file: bool diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 79e023c1..2b849c94 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -66,20 +66,10 @@ class FileVariable(BaseVariable): type = 'file' def valid_value(self, value) -> FileObject: - if isinstance(value, dict): if not value.get("is_file"): raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") - return FileObject( - **{ - "type": str(value.get('type')), - "transfer_method": value.get("transfer_method"), - "url": value.get('url'), - "file_id": value.get("file_id"), - "origin_file_type": value.get("origin_file_type"), - "is_file": True - } - ) + return FileObject(**value) if isinstance(value, FileObject): return value raise 
TypeError(f"Value must be a FileObject - {type(value)}:{value}") @@ -88,13 +78,13 @@ class FileVariable(BaseVariable): return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})' def get_value(self) -> Any: - return self.value.model_dump() + return self.value.model_dump(exclude={"content_cache"}) async def get_content(self): total_bytes = 0 chunks = [] - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(follow_redirects=True) as client: async with client.stream("GET", self.value.url) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(8192): @@ -186,6 +176,8 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T: return BooleanVariable(value) case VariableType.OBJECT: return DictVariable(value) + case VariableType.FILE: + return FileVariable(value) case VariableType.ARRAY_STRING: return make_array(StringVariable, value) case VariableType.ARRAY_NUMBER: diff --git a/api/app/main.py b/api/app/main.py index 9e501f11..a8223a49 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -62,6 +62,7 @@ async def lifespan(app: FastAPI): else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") await create_all_indexes() + logger.info("All neo4j indexes and constraints created successfully!") logger.info("应用程序启动完成") diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 69bedc3d..fab85ea6 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -81,7 +81,7 @@ class ModelConfig(BaseModel): # 模型配置参数 capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), - comment="模型能力列表(如['vision', 'audio', 'video'])") + comment="模型能力列表(如['vision', 'audio', 'video', 'thinking'])") is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") config = Column(JSON, comment="模型配置参数") # - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。 diff --git 
a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index a92b5629..c3fd82df 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -29,11 +29,8 @@ class Tenants(Base): contact_email = Column(String(255), nullable=True) # 联系人邮箱 contact_phone = Column(String(50), nullable=True) # 联系人电话 - # 租户套餐信息 - plan = Column(String(50), nullable=True) # 套餐类型 - plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间 - api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制 - status = Column(String(50), nullable=True, default='active') # 租户状态 + # 租户套餐信息(只读,从 tenant_subscriptions 动态获取) + status = Column(String(50), nullable=True, default='active', server_default='active') # 租户状态 # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") diff --git a/api/app/repositories/app_repository.py b/api/app/repositories/app_repository.py index 75a91fd6..c9d980e2 100644 --- a/api/app/repositories/app_repository.py +++ b/api/app/repositories/app_repository.py @@ -61,3 +61,15 @@ def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App: """根据工作空间ID查询应用""" repo = AppRepository(db) return repo.get_apps_by_id(app_id) + + +def get_release_by_id(db: Session, app_id: uuid.UUID, release_id: uuid.UUID): + """根据发布版本ID查询发布快照(仅返回激活状态)""" + from app.models.app_release_model import AppRelease + return db.scalars( + select(AppRelease).where( + AppRelease.app_id == app_id, + AppRelease.id == release_id, + AppRelease.is_active.is_(True), + ) + ).first() diff --git a/api/app/repositories/home_page_repository.py b/api/app/repositories/home_page_repository.py index 6d74bcaf..d4eaddeb 100644 --- a/api/app/repositories/home_page_repository.py +++ b/api/app/repositories/home_page_repository.py @@ -1,6 +1,6 @@ -from datetime import datetime, timedelta +from datetime import datetime, time from sqlalchemy.orm import Session -from sqlalchemy import func +from sqlalchemy import func, Table, MetaData from uuid import 
UUID from typing import Dict, Optional, Any @@ -192,10 +192,63 @@ class HomePageRepository: return workspaces, app_count_dict, user_count_dict + @staticmethod + def get_latest_version_introduction(db: Session) -> tuple[Optional[str], Optional[Dict[str, Any]]]: + """ + 从数据库获取最新已发布的版本说明 + 使用反射方式读取表结构,不依赖 premium 模型类 + + Args: + db: 数据库会话 + + Returns: + (版本号,版本说明字典) 的元组 + 如果数据库中没有已发布的版本,返回 (None, None) + """ + try: + metadata = MetaData() + + version_notes = Table('version_notes', metadata, autoload_with=db.bind) + + # 获取最新已发布的版本(按发布时间倒序,日期相同时按版本号倒序) + query = db.query(version_notes).filter( + version_notes.c.is_published == True + ).order_by( + version_notes.c.release_date.desc(), + version_notes.c.version.desc() + ) + + note = query.first() + + if not note: + return None, None + + version_info = { + "introduction": { + "codeName": note.code_name or "", + "releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0, + "upgradePosition": note.upgrade_position or "", + "coreUpgrades": note.core_upgrades or [] + }, + "introduction_en": { + "codeName": note.code_name_en or note.code_name or "", + "releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0, + "upgradePosition": note.upgrade_position_en or note.upgrade_position or "", + "coreUpgrades": note.core_upgrades_en or [] + } + } + + return note.version, version_info + + except Exception as e: + import traceback + traceback.print_exc() + return None, None + @staticmethod def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]: """ - 从数据库获取版本说明(优先读取已发布的版本) + 从数据库获取指定版本说明(优先读取已发布的版本) 使用反射方式读取表结构,不依赖 premium 模型类 Args: @@ -207,11 +260,8 @@ class HomePageRepository: 如果数据库中没有该版本,返回 None """ try: - from sqlalchemy import Table, MetaData - metadata = MetaData() version_notes = Table('version_notes', metadata, autoload_with=db.engine) - version_note_items = Table('version_note_items', metadata, 
autoload_with=db.engine) note = db.query(version_notes).filter( version_notes.c.version == version, @@ -221,31 +271,18 @@ class HomePageRepository: if not note: return None - items = db.query(version_note_items).filter( - version_note_items.c.note_id == note.id - ).order_by(version_note_items.c.sort_order).all() - - core_upgrades = [] - for item in items: - title = item.title - content = item.content - if content: - core_upgrades.append(f"{title}
{content}") - else: - core_upgrades.append(title) - return { "introduction": { - "codeName": "", - "releaseDate": note.release_date.isoformat() if note.release_date else "", - "upgradePosition": "", - "coreUpgrades": core_upgrades + "codeName": note.code_name or "", + "releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0, + "upgradePosition": note.upgrade_position or "", + "coreUpgrades": note.core_upgrades or [] }, "introduction_en": { - "codeName": "", - "releaseDate": note.release_date.isoformat() if note.release_date else "", - "upgradePosition": "", - "coreUpgrades": core_upgrades + "codeName": note.code_name_en or note.code_name or "", + "releaseDate": int(datetime.combine(note.release_date, time()).timestamp() * 1000) if note.release_date else 0, + "upgradePosition": note.upgrade_position_en or note.upgrade_position or "", + "coreUpgrades": note.core_upgrades_en or [] } } except Exception: diff --git a/api/app/repositories/implicit_emotions_storage_repository.py b/api/app/repositories/implicit_emotions_storage_repository.py index b6c40b40..b665924d 100644 --- a/api/app/repositories/implicit_emotions_storage_repository.py +++ b/api/app/repositories/implicit_emotions_storage_repository.py @@ -5,16 +5,9 @@ Implicit Emotions Storage Repository 事务由调用方控制,仓储层只使用 flush/refresh """ import logging -from datetime import date, datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Generator, Optional - -class TimeFilterUnavailableError(Exception): - """redis_client 不可用,无法执行时间轴筛选。 - - 调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。 - """ - import redis from sqlalchemy import exists, not_, select from sqlalchemy.orm import Session @@ -25,6 +18,13 @@ from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage logger = logging.getLogger(__name__) +class TimeFilterUnavailableError(Exception): + """redis_client 不可用,无法执行时间轴筛选。 + + 调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。 + 
""" + + class ImplicitEmotionsStorageRepository: """隐性记忆和情绪存储仓储类""" @@ -216,9 +216,7 @@ class ImplicitEmotionsStorageRepository: """ from sqlalchemy import String as SAString from sqlalchemy import cast - CST = timezone(timedelta(hours=8)) - now_cst = datetime.now(CST) - today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None) + today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) tomorrow_start = today_start + timedelta(days=1) offset = 0 while True: diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index 8c477d39..03870b4d 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -263,16 +263,15 @@ class ModelConfigRepository: raise @staticmethod - def get_by_type(db: Session, model_type: ModelType, tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]: - """根据类型获取模型配置""" - db_logger.debug(f"根据类型查询模型配置: type={model_type}, tenant_id={tenant_id}, is_active={is_active}") - + def get_by_type(db: Session, model_types: List[ModelType], tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]: + """根据类型获取模型配置,支持多类型查询""" + db_logger.debug(f"根据类型查询模型配置: types={[t.value for t in model_types]}, tenant_id={tenant_id}, is_active={is_active}") + try: query = db.query(ModelConfig).options( joinedload(ModelConfig.api_keys) - ).filter(ModelConfig.type == model_type) - - # 添加租户过滤 + ).filter(ModelConfig.type.in_([t.value for t in model_types])) + if tenant_id: query = query.filter( or_( @@ -280,16 +279,18 @@ class ModelConfigRepository: ModelConfig.is_public ) ) - + if is_active: query = query.filter(ModelConfig.is_active) - - models = query.order_by(ModelConfig.name).all() + + query = query.filter(ModelConfig.is_composite == False) + + models = query.order_by(ModelConfig.created_at.desc()).all() db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}") 
return models - + except Exception as e: - db_logger.error(f"根据类型查询模型配置失败: type={model_type} - {str(e)}") + db_logger.error(f"根据类型查询模型配置失败: types={model_types} - {str(e)}") raise @staticmethod diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index 5132aa09..7caeea8a 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -1,17 +1,17 @@ -import asyncio from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + async def create_fulltext_indexes(): """Create full-text indexes for keyword search with BM25 scoring.""" connector = Neo4jConnector() try: - # 创建 Statements 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - + """) + # # 创建 Dialogues 索引 # await connector.execute_query(""" # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] @@ -21,27 +21,35 @@ async def create_fulltext_indexes(): await connector.execute_query(""" CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - + """) + # 创建 Chunks 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - + """) + # 创建 MemorySummary 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) + """) # 创建 Community 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } """) - + + # 创建 Perceptual 感知记忆索引 + await 
connector.execute_query(""" + CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain] + OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } + """) + finally: await connector.close() + + async def create_vector_indexes(): """Create vector indexes for fast embedding similarity search. @@ -50,8 +58,7 @@ async def create_vector_indexes(): """ connector = Neo4jConnector() try: - - + # Statement embedding index await connector.execute_query(""" CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS @@ -62,8 +69,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - - + # Chunk embedding index await connector.execute_query(""" CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS @@ -75,7 +81,6 @@ async def create_vector_indexes(): }} """) - # Entity name embedding index await connector.execute_query(""" CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS @@ -86,8 +91,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - - + # Memory summary embedding index await connector.execute_query(""" CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS @@ -98,7 +102,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - + # Community summary embedding index await connector.execute_query(""" CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS @@ -108,8 +112,8 @@ async def create_vector_indexes(): `vector.dimensions`: 1024, `vector.similarity_function`: 'cosine' }} - """) - + """) + # Dialogue embedding index (optional) await connector.execute_query(""" CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS @@ -120,15 +124,27 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - + + # Perceptual summary embedding index + await connector.execute_query(""" + CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS + FOR (p:Perceptual) + ON 
p.summary_embedding + OPTIONS {indexConfig: { + `vector.dimensions`: 1024, + `vector.similarity_function`: 'cosine' + }} + """) finally: await connector.close() + + async def create_unique_constraints(): """Create uniqueness constraints for core node identifiers. Ensures concurrent MERGE operations remain safe and prevents duplicates. """ connector = Neo4jConnector() - try: + try: # Dialogue.id unique await connector.execute_query( """ @@ -136,7 +152,7 @@ async def create_unique_constraints(): FOR (d:Dialogue) REQUIRE d.id IS UNIQUE """ ) - + # Statement.id unique await connector.execute_query( """ @@ -144,7 +160,7 @@ async def create_unique_constraints(): FOR (s:Statement) REQUIRE s.id IS UNIQUE """ ) - + # Chunk.id unique await connector.execute_query( """ @@ -152,13 +168,13 @@ async def create_unique_constraints(): FOR (c:Chunk) REQUIRE c.id IS UNIQUE """ ) - + finally: await connector.close() + + async def create_all_indexes(): """Create all indexes and constraints in one go.""" await create_fulltext_indexes() await create_vector_indexes() await create_unique_constraints() - print("✓ All indexes and constraints created successfully!") - diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 26ffe350..daf04bcb 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -23,6 +23,7 @@ SET s += { end_user_id: statement.end_user_id, stmt_type: statement.stmt_type, statement: statement.statement, + speaker: statement.speaker, emotion_intensity: statement.emotion_intensity, emotion_target: statement.emotion_target, emotion_subject: statement.emotion_subject, @@ -56,6 +57,7 @@ SET c += { expired_at: chunk.expired_at, dialog_id: chunk.dialog_id, content: chunk.content, + speaker: chunk.speaker, chunk_embedding: chunk.chunk_embedding, sequence_number: chunk.sequence_number, start_index: chunk.start_index, @@ -91,6 +93,8 @@ SET e.name = CASE WHEN entity.name IS NOT 
NULL AND entity.name <> '' THEN entity END, e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END, e.aliases = CASE + // 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,知识抽取完全不写入 + WHEN entity.name IN ['用户', '我', 'User', 'I'] THEN e.aliases WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0 THEN CASE WHEN e.aliases IS NULL THEN entity.aliases @@ -283,7 +287,7 @@ LIMIT $limit """ SEARCH_STATEMENTS_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score +CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) @@ -307,7 +311,7 @@ LIMIT $limit """ # 查询实体名称包含指定字符串的实体 SEARCH_ENTITIES_BY_NAME = """ -CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score +CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) @@ -337,21 +341,21 @@ LIMIT $limit """ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """ -CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score +CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) WITH e, score -WITH collect({entity: e, score: score}) AS fulltextResults +With collect({entity: e, score: score}) AS fulltextResults OPTIONAL MATCH (ae:ExtractedEntity) WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id) AND ae.aliases IS NOT NULL - AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q)) + AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query)) WITH fulltextResults, 
collect(ae) AS aliasEntities UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score: CASE - WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 - WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9 ELSE 0.8 END }]) AS row @@ -384,7 +388,7 @@ LIMIT $limit SEARCH_CHUNKS_BY_CONTENT = """ -CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score +CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) @@ -501,7 +505,7 @@ LIMIT $limit """ SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """ -CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score +CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date))) AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date)))) @@ -677,7 +681,7 @@ SET n.invalid_at = $new_invalid_at # MemorySummary keyword search using fulltext index SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score +CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id) OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement) RETURN m.id AS id, @@ -1363,7 +1367,7 @@ RETURN c.community_id AS community_id # Community keyword search: matches name or summary via fulltext index SEARCH_COMMUNITIES_BY_KEYWORD = 
""" -CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score +CALL db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) RETURN c.community_id AS id, c.name AS name, @@ -1449,3 +1453,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id, r.created_at = edge.created_at RETURN elementId(r) AS uuid """ + +SEARCH_PERCEPTUAL_BY_KEYWORD = """ +CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score +WHERE p.end_user_id = $end_user_id +RETURN p.id AS id, + p.end_user_id AS end_user_id, + p.perceptual_type AS perceptual_type, + p.file_path AS file_path, + p.file_name AS file_name, + p.file_ext AS file_ext, + p.summary AS summary, + p.keywords AS keywords, + p.topic AS topic, + p.domain AS domain, + p.created_at AS created_at, + p.file_type AS file_type, + score +ORDER BY score DESC +LIMIT $limit +""" + +PERCEPTUAL_EMBEDDING_SEARCH = """ +CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding) +YIELD node AS p, score +WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id +RETURN p.id AS id, + p.end_user_id AS end_user_id, + p.perceptual_type AS perceptual_type, + p.file_path AS file_path, + p.file_name AS file_name, + p.file_ext AS file_ext, + p.summary AS summary, + p.keywords AS keywords, + p.topic AS topic, + p.domain AS domain, + p.created_at AS created_at, + p.file_type AS file_type, + score +ORDER BY score DESC +LIMIT $limit +""" diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index adc266fe..56feece2 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -186,6 +186,58 @@ async def save_dialog_and_statements_to_neo4j( Returns: bool: True if successful, False otherwise """ + # TODO 需要在去重消歧节阶段,做以下逻辑的处理 + # 预处理:对特殊实体("用户"、"AI助手")复用 Neo4j 中已有节点的 ID, + # 确保同一个 
end_user_id 下只有一个"用户"节点和一个"AI助手"节点。 + if entity_nodes: + _SPECIAL_NAMES = {"用户", "我", "user", "i", "ai助手", "助手", "ai assistant", "assistant"} + end_user_id = entity_nodes[0].end_user_id if entity_nodes else None + if end_user_id: + try: + # 查询已有的特殊实体 + cypher = """ + MATCH (e:ExtractedEntity) + WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $names + RETURN e.id AS id, e.name AS name + """ + existing = await connector.execute_query( + cypher, + end_user_id=end_user_id, + names=list(_SPECIAL_NAMES), + ) + # 建立 name(lower) → existing_id 映射 + existing_id_map = {} + for record in (existing or []): + name_lower = (record.get("name") or "").strip().lower() + if name_lower and record.get("id"): + existing_id_map[name_lower] = record["id"] + + if existing_id_map: + # 替换新实体的 ID 为已有 ID,同时更新所有引用该 ID 的边 + for ent in entity_nodes: + name_lower = (ent.name or "").strip().lower() + if name_lower in existing_id_map: + old_id = ent.id + new_id = existing_id_map[name_lower] + if old_id != new_id: + ent.id = new_id + # 更新 statement_entity_edges 中的引用 + for edge in statement_entity_edges: + if edge.target == old_id: + edge.target = new_id + if edge.source == old_id: + edge.source = new_id + # 更新 entity_edges 中的引用 + for edge in entity_edges: + if edge.source == old_id: + edge.source = new_id + if edge.target == old_id: + edge.target = new_id + logger.info( + f"特殊实体 '{ent.name}' ID 复用: {old_id[:8]}... → {new_id[:8]}..." 
+ ) + except Exception as e: + logger.warning(f"特殊实体 ID 复用查询失败(不影响写入): {e}") # 定义事务函数,将所有写操作放在一个事务中 async def _save_all_in_transaction(tx): diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index c5d3bcca..a191dad6 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -2,12 +2,14 @@ import asyncio import logging from typing import Any, Dict, List, Optional +from app.core.memory.utils.data.text_utils import escape_lucene_query from app.repositories.neo4j.cypher_queries import ( CHUNK_EMBEDDING_SEARCH, COMMUNITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH, EXPAND_COMMUNITY_STATEMENTS, MEMORY_SUMMARY_EMBEDDING_SEARCH, + PERCEPTUAL_EMBEDDING_SEARCH, SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNKS_BY_CONTENT, SEARCH_COMMUNITIES_BY_KEYWORD, @@ -15,6 +17,7 @@ from app.repositories.neo4j.cypher_queries import ( SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, + SEARCH_PERCEPTUAL_BY_KEYWORD, SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, @@ -34,11 +37,11 @@ logger = logging.getLogger(__name__) async def _update_activation_values_batch( - connector: Neo4jConnector, - nodes: List[Dict[str, Any]], - node_label: str, - end_user_id: Optional[str] = None, - max_retries: int = 3 + connector: Neo4jConnector, + nodes: List[Dict[str, Any]], + node_label: str, + end_user_id: Optional[str] = None, + max_retries: int = 3 ) -> List[Dict[str, Any]]: """ 批量更新节点的激活值 @@ -58,7 +61,7 @@ async def _update_activation_values_batch( """ if not nodes: return [] - + # 延迟导入以避免循环依赖 from app.core.memory.storage_services.forgetting_engine.access_history_manager import ( AccessHistoryManager, @@ -66,7 +69,7 @@ async def _update_activation_values_batch( from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( ACTRCalculator, ) - + # 创建计算器和管理器实例 actr_calculator = ACTRCalculator() 
access_manager = AccessHistoryManager( @@ -74,7 +77,7 @@ async def _update_activation_values_batch( actr_calculator=actr_calculator, max_retries=max_retries ) - + # 提取节点ID列表并去重(保持原始顺序) seen_ids = set() unique_node_ids = [] @@ -83,9 +86,9 @@ async def _update_activation_values_batch( if node_id and node_id not in seen_ids: seen_ids.add(node_id) unique_node_ids.append(node_id) - + if not unique_node_ids: - logger.warning(f"批量更新激活值:没有有效的节点ID") + logger.warning("批量更新激活值:没有有效的节点ID") return nodes # 记录去重信息(仅针对具有有效 ID 的节点) @@ -95,7 +98,7 @@ async def _update_activation_values_batch( f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, " f"去重后唯一ID数量={len(unique_node_ids)}" ) - + # 批量记录访问 try: updated_nodes = await access_manager.record_batch_access( @@ -103,14 +106,14 @@ async def _update_activation_values_batch( node_label=node_label, end_user_id=end_user_id ) - + logger.info( f"批量更新激活值成功: {node_label}, " f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}" ) - + return updated_nodes - + except Exception as e: logger.error( f"批量更新激活值失败: {node_label}, 错误: {str(e)}" @@ -120,9 +123,9 @@ async def _update_activation_values_batch( async def _update_search_results_activation( - connector: Neo4jConnector, - results: Dict[str, List[Dict[str, Any]]], - end_user_id: Optional[str] = None + connector: Neo4jConnector, + results: Dict[str, List[Dict[str, Any]]], + end_user_id: Optional[str] = None ) -> Dict[str, List[Dict[str, Any]]]: """ 更新搜索结果中所有知识节点的激活值 @@ -144,11 +147,11 @@ async def _update_search_results_activation( 'entities': 'ExtractedEntity', 'summaries': 'MemorySummary' } - + # 并行更新所有类型的节点 update_tasks = [] update_keys = [] - + for key, label in knowledge_node_types.items(): if key in results and results[key]: update_tasks.append( @@ -160,13 +163,13 @@ async def _update_search_results_activation( ) ) update_keys.append(key) - + if not update_tasks: return results - + # 并行执行所有更新 update_results = await asyncio.gather(*update_tasks, return_exceptions=True) - + # 更新结果字典,保留原始搜索分数 
updated_results = results.copy() for key, update_result in zip(update_keys, update_results): @@ -175,10 +178,10 @@ async def _update_search_results_activation( # 保留原始的 score 字段(BM25/Embedding 分数) original_nodes = results[key] updated_nodes = update_result - + # 创建 ID 到更新节点的映射(用于快速查找激活值数据) updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')} - + # 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充 merged_nodes = [] for original_node in original_nodes: @@ -186,7 +189,7 @@ async def _update_search_results_activation( if node_id and node_id in updated_map: # 从原始节点开始,用更新后的激活值数据覆盖 merged_node = original_node.copy() - + # 更新激活值相关字段 activation_fields = { 'activation_value', @@ -196,35 +199,35 @@ async def _update_search_results_activation( 'importance_score', 'version', 'statement', # Statement 节点的内容字段 - 'content' # MemorySummary 节点的内容字段 + 'content' # MemorySummary 节点的内容字段 } - + # 只更新激活值相关字段,保留原始节点的其他字段 for field in activation_fields: if field in updated_map[node_id]: merged_node[field] = updated_map[node_id][field] - + merged_nodes.append(merged_node) else: # 如果没有更新数据,保留原始节点 merged_nodes.append(original_node) - + updated_results[key] = merged_nodes else: # 更新失败,记录错误但保留原始结果 logger.warning( f"更新 {key} 激活值失败: {str(update_result)}" ) - + return updated_results async def search_graph( - connector: Neo4jConnector, - q: str, - end_user_id: Optional[str] = None, - limit: int = 50, - include: List[str] = None, + connector: Neo4jConnector, + query: str, + end_user_id: Optional[str] = None, + limit: int = 50, + include: List[str] = None, ) -> Dict[str, List[Dict[str, Any]]]: """ Search across Statements, Entities, Chunks, and Summaries using a free-text query. 
@@ -232,14 +235,14 @@ async def search_graph( OPTIMIZED: Runs all queries in parallel using asyncio.gather() INTEGRATED: Updates activation values for knowledge nodes before returning results - - Statements: matches s.statement CONTAINS q - - Entities: matches e.name CONTAINS q - - Chunks: matches s.content CONTAINS q (from Statement nodes) - - Summaries: matches ms.content CONTAINS q + - Statements: matches s.statement CONTAINS query + - Entities: matches e.name CONTAINS query + - Chunks: matches s.content CONTAINS query (from Statement nodes) + - Summaries: matches ms.content CONTAINS query Args: connector: Neo4j connector - q: Query text + query: Query text for full-text search end_user_id: Optional group filter limit: Max results per category include: List of categories to search (default: all) @@ -249,42 +252,49 @@ async def search_graph( """ if include is None: include = ["statements", "chunks", "entities", "summaries"] - + + # Escape Lucene special characters to prevent query parse errors + escaped_query = escape_lucene_query(query) + # Prepare tasks for parallel execution tasks = [] task_keys = [] - + if "statements" in include: tasks.append(connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD, - q=q, + json_format=True, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) task_keys.append("statements") - + if "entities" in include: tasks.append(connector.execute_query( SEARCH_ENTITIES_BY_NAME_OR_ALIAS, - q=q, + json_format=True, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) task_keys.append("entities") - + if "chunks" in include: tasks.append(connector.execute_query( SEARCH_CHUNKS_BY_CONTENT, - q=q, + json_format=True, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) task_keys.append("chunks") - + if "summaries" in include: tasks.append(connector.execute_query( SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, - q=q, + json_format=True, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) @@ -293,15 +303,16 @@ 
async def search_graph( if "communities" in include: tasks.append(connector.execute_query( SEARCH_COMMUNITIES_BY_KEYWORD, - q=q, + json_format=True, + query=escaped_query, end_user_id=end_user_id, limit=limit, )) task_keys.append("communities") - + # Execute all queries in parallel task_results = await asyncio.gather(*tasks, return_exceptions=True) - + # Build results dictionary results = {} for key, result in zip(task_keys, task_results): @@ -310,14 +321,14 @@ async def search_graph( results[key] = [] else: results[key] = result - + # Deduplicate results before updating activation values # This prevents duplicates from propagating through the pipeline from app.core.memory.src.search import _deduplicate_results for key in results: if isinstance(results[key], list): results[key] = _deduplicate_results(results[key]) - + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) # Skip activation updates if only searching summaries (optimization) needs_activation_update = any( @@ -331,17 +342,17 @@ async def search_graph( results=results, end_user_id=end_user_id ) - + return results async def search_graph_by_embedding( - connector: Neo4jConnector, - embedder_client, - query_text: str, - end_user_id: Optional[str] = None, - limit: int = 50, - include: List[str] = ["statements", "chunks", "entities","summaries"], + connector: Neo4jConnector, + embedder_client, + query_text: str, + end_user_id: Optional[str] = None, + limit: int = 50, + include: List[str] = ["statements", "chunks", "entities", "summaries"], ) -> Dict[str, List[Dict[str, Any]]]: """ Embedding-based semantic search across Statements, Chunks, and Entities. 
@@ -355,13 +366,13 @@ async def search_graph_by_embedding( - Returns up to 'limit' per included type """ import time - + # Get embedding for the query embed_start = time.time() embeddings = await embedder_client.response([query_text]) embed_time = time.time() - embed_start - print(f"[PERF] Embedding generation took: {embed_time:.4f}s") - + logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s") + if not embeddings or not embeddings[0]: logger.warning( f"search_graph_by_embedding: embedding 生成失败或为空," @@ -378,6 +389,7 @@ async def search_graph_by_embedding( if "statements" in include: tasks.append(connector.execute_query( STATEMENT_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -388,6 +400,7 @@ async def search_graph_by_embedding( if "chunks" in include: tasks.append(connector.execute_query( CHUNK_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -398,6 +411,7 @@ async def search_graph_by_embedding( if "entities" in include: tasks.append(connector.execute_query( ENTITY_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -408,6 +422,7 @@ async def search_graph_by_embedding( if "summaries" in include: tasks.append(connector.execute_query( MEMORY_SUMMARY_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -418,6 +433,7 @@ async def search_graph_by_embedding( if "communities" in include: tasks.append(connector.execute_query( COMMUNITY_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -428,8 +444,8 @@ async def search_graph_by_embedding( query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) query_time = time.time() - query_start - print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") - + logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") + # 
Build results dictionary results: Dict[str, List[Dict[str, Any]]] = { "statements": [], @@ -438,7 +454,7 @@ async def search_graph_by_embedding( "summaries": [], "communities": [], } - + for key, result in zip(task_keys, task_results): if isinstance(result, Exception): logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}") @@ -470,16 +486,18 @@ async def search_graph_by_embedding( update_time = time.time() - update_start logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") else: - logger.info(f"[PERF] Skipping activation updates (only summaries)") + logger.info("[PERF] Skipping activation updates (only summaries)") return results + + async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 - connector: Neo4jConnector, - end_user_id: str, - entities: List[Dict[str, Any]], - use_contains_fallback: bool = True, - batch_size: int = 500, - max_concurrency: int = 5, + connector: Neo4jConnector, + end_user_id: str, + entities: List[Dict[str, Any]], + use_contains_fallback: bool = True, + batch_size: int = 500, + max_concurrency: int = 5, ) -> Dict[str, List[Dict[str, Any]]]: """ 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries): @@ -506,7 +524,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 # 全文索引按名称检索(包含 CONTAINS 语义) rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, - q=name, + query=escape_lucene_query(name), end_user_id=end_user_id, limit=100, ) @@ -530,7 +548,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 try: rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, - q=name.lower(), + query=escape_lucene_query(name.lower()), end_user_id=end_user_id, limit=100, ) @@ -560,14 +578,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 async def search_graph_by_keyword_temporal( - connector: Neo4jConnector, - query_text: str, - end_user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - 
invalid_date: Optional[str] = None, - limit: int = 50, + connector: Neo4jConnector, + query_text: str, + end_user_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = None, + limit: int = 50, ) -> Dict[str, List[Any]]: """ Temporal keyword search across Statements. @@ -579,11 +597,12 @@ async def search_graph_by_keyword_temporal( - Returns up to 'limit' statements """ if not query_text: - print(f"query_text不能为空") + logger.warning("query_text不能为空") return {"statements": []} + escaped_query = escape_lucene_query(query_text) statements = await connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, - q=query_text, + query=escaped_query, end_user_id=end_user_id, start_date=start_date, end_date=end_date, @@ -591,7 +610,7 @@ async def search_graph_by_keyword_temporal( invalid_date=invalid_date, limit=limit, ) - print(f"查询结果为:\n{statements}") + logger.debug(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -605,13 +624,13 @@ async def search_graph_by_keyword_temporal( async def search_graph_by_temporal( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 10, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = None, + limit: int = 10, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. 
@@ -632,10 +651,6 @@ async def search_graph_by_temporal( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -643,15 +658,15 @@ async def search_graph_by_temporal( results=results, end_user_id=end_user_id ) - + return results async def search_graph_by_dialog_id( - connector: Neo4jConnector, - dialog_id: str, - end_user_id: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + dialog_id: str, + end_user_id: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Dialogues. @@ -661,7 +676,7 @@ async def search_graph_by_dialog_id( - Returns up to 'limit' dialogues """ if not dialog_id: - print(f"dialog_id不能为空") + logger.warning("dialog_id不能为空") return {"dialogues": []} dialogues = await connector.execute_query( @@ -674,13 +689,13 @@ async def search_graph_by_dialog_id( async def search_graph_by_chunk_id( - connector: Neo4jConnector, - chunk_id : str, - end_user_id: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + chunk_id: str, + end_user_id: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - print(f"chunk_id不能为空") + logger.warning("chunk_id不能为空") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, @@ -692,10 +707,10 @@ async def search_graph_by_chunk_id( async def search_graph_community_expand( - connector: Neo4jConnector, - community_ids: List[str], - end_user_id: str, - limit: int = 10, + connector: Neo4jConnector, + community_ids: List[str], + end_user_id: str, + limit: int = 10, ) -> Dict[str, List[Dict[str, Any]]]: """ 三期:社区展开检索 —— 主题 → 细节两级检索。 @@ -748,12 +763,11 @@ 
async def search_graph_community_expand( async def search_graph_by_created_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - created_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -767,16 +781,11 @@ async def search_graph_by_created_at( statements = await connector.execute_query( SEARCH_STATEMENTS_BY_CREATED_AT, end_user_id=end_user_id, - - + created_at=created_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -784,16 +793,16 @@ async def search_graph_by_created_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_by_valid_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - valid_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. 
@@ -807,16 +816,11 @@ async def search_graph_by_valid_at( statements = await connector.execute_query( SEARCH_STATEMENTS_BY_VALID_AT, end_user_id=end_user_id, - - + valid_at=valid_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -824,16 +828,16 @@ async def search_graph_by_valid_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_g_created_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - created_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -847,16 +851,11 @@ async def search_graph_g_created_at( statements = await connector.execute_query( SEARCH_STATEMENTS_G_CREATED_AT, end_user_id=end_user_id, - - + created_at=created_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -864,16 +863,16 @@ async def search_graph_g_created_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_g_valid_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - valid_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. 
@@ -887,16 +886,10 @@ async def search_graph_g_valid_at( statements = await connector.execute_query( SEARCH_STATEMENTS_G_VALID_AT, end_user_id=end_user_id, - - valid_at=valid_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -904,16 +897,16 @@ async def search_graph_g_valid_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_l_created_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - created_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -927,16 +920,11 @@ async def search_graph_l_created_at( statements = await connector.execute_query( SEARCH_STATEMENTS_L_CREATED_AT, end_user_id=end_user_id, - - + created_at=created_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -944,16 +932,16 @@ async def search_graph_l_created_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_l_valid_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - valid_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. 
@@ -967,16 +955,11 @@ async def search_graph_l_valid_at( statements = await connector.execute_query( SEARCH_STATEMENTS_L_VALID_AT, end_user_id=end_user_id, - - + valid_at=valid_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -984,5 +967,89 @@ async def search_graph_l_valid_at( results=results, end_user_id=end_user_id ) - + return results + + +async def search_perceptual( + connector: Neo4jConnector, + query: str, + end_user_id: Optional[str] = None, + limit: int = 10, +) -> Dict[str, List[Dict[str, Any]]]: + """ + Search Perceptual memory nodes using fulltext keyword search. + + Matches against summary, topic, and domain fields via the perceptualFulltext index. + + Args: + connector: Neo4j connector + query: Query text for full-text search + end_user_id: Optional user filter + limit: Max results + + Returns: + Dictionary with 'perceptuals' key containing matched perceptual memory nodes + """ + try: + perceptuals = await connector.execute_query( + SEARCH_PERCEPTUAL_BY_KEYWORD, + query=escape_lucene_query(query), + end_user_id=end_user_id, + limit=limit, + ) + except Exception as e: + logger.warning(f"search_perceptual: keyword search failed: {e}") + perceptuals = [] + + # Deduplicate + from app.core.memory.src.search import _deduplicate_results + perceptuals = _deduplicate_results(perceptuals) + + return {"perceptuals": perceptuals} + + +async def search_perceptual_by_embedding( + connector: Neo4jConnector, + embedder_client, + query_text: str, + end_user_id: Optional[str] = None, + limit: int = 10, +) -> Dict[str, List[Dict[str, Any]]]: + """ + Search Perceptual memory nodes using embedding-based semantic search. + + Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index. 
+ + Args: + connector: Neo4j connector + embedder_client: Embedding client with async response() method + query_text: Query text to embed + end_user_id: Optional user filter + limit: Max results + + Returns: + Dictionary with 'perceptuals' key containing matched perceptual memory nodes + """ + embeddings = await embedder_client.response([query_text]) + if not embeddings or not embeddings[0]: + logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'") + return {"perceptuals": []} + + embedding = embeddings[0] + + try: + perceptuals = await connector.execute_query( + PERCEPTUAL_EMBEDDING_SEARCH, + embedding=embedding, + end_user_id=end_user_id, + limit=limit, + ) + except Exception as e: + logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}") + perceptuals = [] + + from app.core.memory.src.search import _deduplicate_results + perceptuals = _deduplicate_results(perceptuals) + + return {"perceptuals": perceptuals} diff --git a/api/app/repositories/neo4j/neo4j_connector.py b/api/app/repositories/neo4j/neo4j_connector.py index d96e4431..d20bf75f 100644 --- a/api/app/repositories/neo4j/neo4j_connector.py +++ b/api/app/repositories/neo4j/neo4j_connector.py @@ -11,10 +11,28 @@ Classes: from typing import Any, List, Dict from neo4j import AsyncGraphDatabase, basic_auth +from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration from app.core.config import settings +def _convert_neo4j_types(value: Any) -> Any: + """递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。""" + if isinstance(value, Neo4jDateTime): + return value.to_native().isoformat() if value.tzinfo else value.iso_format() + if isinstance(value, Neo4jDate): + return value.iso_format() + if isinstance(value, Neo4jTime): + return value.iso_format() + if isinstance(value, Neo4jDuration): + return str(value) + if isinstance(value, dict): + return {k: _convert_neo4j_types(v) for k, v in 
value.items()} + if isinstance(value, list): + return [_convert_neo4j_types(item) for item in value] + return value + + class Neo4jConnector: """Neo4j数据库连接器 @@ -59,11 +77,12 @@ class Neo4jConnector: """ await self.driver.close() - async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + async def execute_query(self, cypher: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]: """执行Cypher查询 Args: - query: Cypher查询语句 + cypher: Cypher查询语句 + json_format: json格式化 **kwargs: 查询参数,将作为参数传递给Cypher查询 Returns: @@ -73,12 +92,15 @@ class Neo4jConnector: """ result = await self.driver.execute_query( - query, + cypher, database="neo4j", **kwargs ) records, summary, keys = result - return [record.data() for record in records] + if json_format: + return [_convert_neo4j_types(record.data()) for record in records] + else: + return [record.data() for record in records] async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any: """在写事务中执行操作 diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index 1a9b0b87..1348c4e8 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -161,6 +161,17 @@ class BuiltinToolRepository: BuiltinToolConfig.id == tool_id ).first() + @staticmethod + def get_existing_tool_classes(db: Session, tenant_id: uuid.UUID) -> set: + """获取该租户已有的内置工具 tool_class 集合""" + rows = db.query(BuiltinToolConfig.tool_class).join( + ToolConfig, BuiltinToolConfig.id == ToolConfig.id + ).filter( + ToolConfig.tenant_id == tenant_id, + ToolConfig.tool_type == ToolType.BUILTIN.value + ).all() + return {row[0] for row in rows} + class CustomToolRepository: """自定义工具仓储类""" diff --git a/api/app/repositories/user_repository.py b/api/app/repositories/user_repository.py index af4449e5..6874f9bf 100644 --- a/api/app/repositories/user_repository.py +++ b/api/app/repositories/user_repository.py @@ -23,7 +23,7 @@ class UserRepository: 
db_logger.debug(f"根据 ID 查询用户:user_id={user_id}") try: - user = self.db.query(User).options(joinedload(User.tenant)).filter(User.id == user_id).first() + user = self.db.query(User).options(joinedload(User.tenant)).filter(User.id == user_id, User.is_active.is_(True)).first() if user: # 检查租户状态,租户禁用时返回 None if user.tenant and not user.tenant.is_active: @@ -297,6 +297,10 @@ def get_user_by_id(db: Session, user_id: uuid.UUID) -> Optional[User]: """根据ID获取用户""" return UserRepository(db).get_user_by_id(user_id) +def get_user_by_id_regardless_active(db: Session, user_id: uuid.UUID) -> Optional[User]: + """根据ID获取用户(不过滤 is_active,用于启用/禁用场景)""" + return db.query(User).filter(User.id == user_id).first() + def get_user_by_email(db: Session, email: str) -> Optional[User]: """根据邮箱获取用户""" return UserRepository(db).get_user_by_email(email) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index f1e9132f..e93c513d 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -44,6 +44,8 @@ class FileInput(BaseModel): upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)") url: Optional[str] = Field(None, description="远程URL(remote_url时必填)") file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)") + name: Optional[str] = Field(None, description="文件名") + size: Optional[int] = Field(None, description="文件大小(字节)") _content = None @@ -241,6 +243,9 @@ class ModelParameters(BaseModel): presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚") n: int = Field(default=1, ge=1, le=10, description="生成的回复数量") stop: Optional[List[str]] = Field(default=None, description="停止序列") + deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)") + thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)") + json_output: bool = Field(default=False, description="是否强制 
JSON 格式输出(需模型支持 json_output 能力)") class VariableDefinition(BaseModel): @@ -612,7 +617,9 @@ class AppChatRequest(BaseModel): user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") + thinking: bool = Field(default=False, description="是否启用深度思考(需Agent配置支持)") files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)") + version: Optional[uuid.UUID] = Field(default=None, description="指定发布版本ID,不传则使用当前生效版本") class DraftRunRequest(BaseModel): @@ -641,6 +648,7 @@ class CitationSource(BaseModel): class DraftRunResponse(BaseModel): """试运行响应(非流式)""" message: str = Field(..., description="AI 回复消息") + reasoning_content: Optional[str] = Field(default=None, description="深度思考内容") conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)") usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况") elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)") @@ -648,6 +656,12 @@ class DraftRunResponse(BaseModel): citations: List[CitationSource] = Field(default_factory=list, description="引用来源") audio_url: Optional[str] = Field(default=None, description="TTS 语音URL") + def model_dump(self, **kwargs): + data = super().model_dump(**kwargs) + if not data.get("reasoning_content"): + data.pop("reasoning_content", None) + return data + class OpeningResponse(BaseModel): """应用开场白响应""" diff --git a/api/app/schemas/conversation_schema.py b/api/app/schemas/conversation_schema.py index 98715612..fd1be5d9 100644 --- a/api/app/schemas/conversation_schema.py +++ b/api/app/schemas/conversation_schema.py @@ -31,7 +31,8 @@ class ChatRequest(BaseModel): stream: bool = Field(default=False, description="是否流式返回") web_search: bool = Field(default=False, description="是否启用网络搜索") memory: bool = Field(default=True, description="是否启用记忆功能") - files: Optional[List[FileInput]] = 
Field(default=None, description="附件列表(支持多文件)") + thinking: bool = Field(default=False, description="是否启用深度思考(需Agent配置支持)") + files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)") # ---------- Output Schemas ---------- diff --git a/api/app/services/api_key_service.py b/api/app/services/api_key_service.py index a49e8fe0..07d55198 100644 --- a/api/app/services/api_key_service.py +++ b/api/app/services/api_key_service.py @@ -248,6 +248,35 @@ class RateLimiterService: def __init__(self): self.redis = aio_redis + async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: + """ + 按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit + """ + now = time.time() + window_start = now - 1 # 1 秒窗口 + key = f"rate_limit:tenant_qps:{tenant_id}" + + async with self.redis.pipeline() as pipe: + # 清理 1 秒前的旧记录 + pipe.zremrangebyscore(key, 0, window_start) + # 加入当前请求(score=时间戳,member=时间戳+随机数保证唯一) + pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now}) + # 统计窗口内请求数 + pipe.zcard(key) + # 设置 key 过期(2 秒后自动清理) + pipe.expire(key, 2) + results = await pipe.execute() + + current = results[2] + remaining = max(0, limit - current) + reset_time = int(now) + 1 + + return current <= limit, { + "limit": limit, + "remaining": remaining, + "reset": reset_time, + } + async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: """ 检查QPS限制 diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index df81568f..56e25713 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -26,6 +26,7 @@ from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService from app.services.workflow_service import WorkflowService +from app.models.file_metadata_model import FileMetadata logger = get_business_logger() @@ -117,7 +118,10 @@ class 
AppChatService: max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, - + deep_thinking=model_parameters.get("deep_thinking", False), + thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), + json_output=model_parameters.get("json_output", False), + capability=api_key_obj.capability or [], ) model_info = ModelInfo( @@ -163,7 +167,14 @@ class AppChatService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - + # 为需要运行时上下文的工具注入上下文 + for t in tools: + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) # 调用 Agent(支持多模态) result = await agent.chat( message=message, @@ -205,14 +216,33 @@ class AppChatService: "model": api_key_obj.model_name, "usage": result.get("usage", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}), "audio_url": None, - "citations": filtered_citations + "citations": filtered_citations, + "reasoning_content": result.get("reasoning_content") } if files: + local_ids = [f.upload_file_id for f in files + if f.transfer_method.value == "local_file" and f.upload_file_id + and (not f.name or not f.size)] + meta_map = {} + if local_ids: + rows = self.db.query(FileMetadata).filter( + FileMetadata.id.in_(local_ids), + FileMetadata.status == "completed" + ).all() + meta_map = {str(r.id): r for r in rows} for f in files: - # url = await MultimodalService(self.db).get_file_url(f) + name, size = f.name, f.size + if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size): + meta = meta_map.get(str(f.upload_file_id)) + if meta: + name = name or meta.file_name + size = size or meta.file_size human_meta["files"].append({ "type": 
f.type, - "url": f.url + "url": f.url, + "name": name, + "size": size, + "file_type": f.file_type, }) if processed_files: @@ -228,8 +258,13 @@ class AppChatService: if memory_flag: connected_config = get_end_user_connected_config(user_id, self.db) memory_config_id: str = connected_config.get("memory_config_id") + file_list = [] + for file in files: + file_dict = file.model_dump() + file_dict["upload_file_id"] = str(file_dict["upload_file_id"]) if file_dict["upload_file_id"] else None + file_list.append(file_dict) messages = [ - {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "user", "content": message, "files": file_list}, {"role": "assistant", "content": result["content"]} ] if memory_config_id: @@ -258,6 +293,7 @@ class AppChatService: "conversation_id": conversation_id, "message_id": str(message_id), "message": result["content"], + "reasoning_content": result.get("reasoning_content"), "usage": result.get("usage", { "prompt_tokens": 0, "completion_tokens": 0, @@ -354,7 +390,11 @@ class AppChatService: max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, - streaming=True + streaming=True, + deep_thinking=model_parameters.get("deep_thinking", False), + thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), + json_output=model_parameters.get("json_output", False), + capability=api_key_obj.capability or [], ) model_info = ModelInfo( @@ -401,8 +441,18 @@ class AppChatService: processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") + # 为需要运行时上下文的工具注入上下文 + for t in tools: + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) + # 流式调用 Agent(支持多模态),同时并行启动 TTS full_content = "" + full_reasoning = "" 
total_tokens = 0 text_queue: asyncio.Queue = asyncio.Queue() @@ -426,6 +476,9 @@ class AppChatService: ): if isinstance(chunk, int): total_tokens = chunk + elif isinstance(chunk, dict) and chunk.get("type") == "reasoning": + full_reasoning += chunk['content'] + yield f"event: reasoning\ndata: {json.dumps({'content': chunk['content']}, ensure_ascii=False)}\n\n" else: full_content += chunk yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" @@ -472,14 +525,34 @@ class AppChatService: "model": api_key_obj.model_name, "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}, "audio_url": None, - "citations": filtered_citations + "citations": filtered_citations, + "reasoning_content": full_reasoning or None } if files: + local_ids = [f.upload_file_id for f in files + if f.transfer_method.value == "local_file" and f.upload_file_id + and (not f.name or not f.size)] + meta_map = {} + if local_ids: + rows = self.db.query(FileMetadata).filter( + FileMetadata.id.in_(local_ids), + FileMetadata.status == "completed" + ).all() + meta_map = {str(r.id): r for r in rows} for f in files: + name, size = f.name, f.size + if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size): + meta = meta_map.get(str(f.upload_file_id)) + if meta: + name = name or meta.file_name + size = size or meta.file_size human_meta["files"].append({ "type": f.type, - "url": f.url + "url": f.url, + "name": name, + "size": size, + "file_type": f.file_type, }) if processed_files: human_meta["history_files"] = { @@ -494,8 +567,13 @@ class AppChatService: if memory_flag: connected_config = get_end_user_connected_config(user_id, self.db) memory_config_id: str = connected_config.get("memory_config_id") + file_list = [] + for file in files: + file_dict = file.model_dump() + file_dict["upload_file_id"] = str(file_dict["upload_file_id"]) if file_dict["upload_file_id"] else None + file_list.append(file_dict) messages = [ - 
{"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "user", "content": message, "files": file_list}, {"role": "assistant", "content": full_content} ] if memory_config_id: @@ -652,13 +730,13 @@ class AppChatService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - if "sub_usage" in event: + # 拦截 sub_usage 事件,累加 token + if "event: sub_usage" in event: if "data:" in event: try: data_line = event.split("data: ", 1)[1].strip() data = json.loads(data_line) - if "total_tokens" in data: - total_tokens += data["total_tokens"] + total_tokens += data.get("total_tokens", 0) except: pass else: diff --git a/api/app/services/app_dsl_service.py b/api/app/services/app_dsl_service.py index 8c198be4..c527e4d3 100644 --- a/api/app/services/app_dsl_service.py +++ b/api/app/services/app_dsl_service.py @@ -73,15 +73,14 @@ class AppDslService: AppType.MULTI_AGENT: "multi_agent_config", AppType.WORKFLOW: "workflow" }.get(app.type, "config") - config_data = self._enrich_release_config(app.type, release.config or {}) + config_data = self._enrich_release_config(app.type, release.config or {}, release.default_model_config_id) dsl = {**meta, "app": app_meta, config_key: config_data} return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml" - def _enrich_release_config(self, app_type: str, cfg: dict) -> dict: + def _enrich_release_config(self, app_type: str, cfg: dict, default_model_config_id=None) -> dict: if app_type == AppType.AGENT: enriched = {**cfg} - if "default_model_config_id" in cfg: - enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"]) + enriched["default_model_config_ref"] = self._model_ref(default_model_config_id) if "knowledge_retrieval" in cfg: enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"]) if "tools" in cfg: @@ -91,8 +90,7 @@ class AppDslService: return enriched if app_type == 
AppType.MULTI_AGENT: enriched = {**cfg} - if "default_model_config_id" in cfg: - enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"]) + enriched["default_model_config_ref"] = self._model_ref(default_model_config_id) if "master_agent_id" in cfg: enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"]) if "sub_agents" in cfg: @@ -229,8 +227,11 @@ class AppDslService: workspace_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID, + app_id: Optional[uuid.UUID] = None, ) -> tuple[App, list[str]]: - """解析 DSL,创建应用及配置,返回 (new_app, warnings)""" + """解析 DSL,创建或覆盖应用配置,返回 (app, warnings)。 + app_id 不为空时:校验类型一致后覆盖配置;为空时创建新应用。 + """ app_meta = dsl.get("app", {}) app_type = app_meta.get("type") if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW): @@ -239,6 +240,9 @@ class AppDslService: warnings: list[str] = [] now = datetime.datetime.now() + if app_id is not None: + return self._overwrite_dsl(dsl, app_id, app_type, workspace_id, tenant_id, warnings, now) + new_app = App( id=uuid.uuid4(), workspace_id=workspace_id, @@ -258,11 +262,57 @@ class AppDslService: self.db.add(new_app) self.db.flush() + self._write_config(new_app.id, app_type, dsl, workspace_id, tenant_id, warnings, now, create=True) + + self.db.commit() + self.db.refresh(new_app) + return new_app, warnings + + def _overwrite_dsl( + self, + dsl: dict, + app_id: uuid.UUID, + app_type: str, + workspace_id: uuid.UUID, + tenant_id: uuid.UUID, + warnings: list, + now: datetime.datetime, + ) -> tuple[App, list[str]]: + """覆盖已有应用的配置,类型不一致时抛出异常""" + app = self.db.query(App).filter( + App.id == app_id, + App.workspace_id == workspace_id, + App.is_active.is_(True) + ).first() + if not app: + raise ResourceNotFoundException("应用", str(app_id)) + if app.type != app_type: + raise BusinessException( + f"YAML 类型 '{app_type}' 与应用类型 '{app.type}' 不一致,无法导入", + BizCode.BAD_REQUEST + ) + + self._write_config(app_id, app_type, dsl, workspace_id, tenant_id, warnings, 
now, create=False) + + self.db.commit() + self.db.refresh(app) + return app, warnings + + def _write_config( + self, + app_id: uuid.UUID, + app_type: str, + dsl: dict, + workspace_id: uuid.UUID, + tenant_id: uuid.UUID, + warnings: list, + now: datetime.datetime, + create: bool, + ) -> None: + """写入(新建或覆盖)应用配置""" if app_type == AppType.AGENT: cfg = dsl.get("agent_config") or {} - self.db.add(AgentConfig( - id=uuid.uuid4(), - app_id=new_app.id, + fields = dict( system_prompt=cfg.get("system_prompt"), model_parameters=cfg.get("model_parameters"), default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings), @@ -272,16 +322,21 @@ class AppDslService: tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings), skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings), features=cfg.get("features", {}), - is_active=True, - created_at=now, updated_at=now, - )) + ) + if create: + self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) + else: + existing = self.db.query(AgentConfig).filter(AgentConfig.app_id == app_id).first() + if existing: + for k, v in fields.items(): + setattr(existing, k, v) + else: + self.db.add(AgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) elif app_type == AppType.MULTI_AGENT: cfg = dsl.get("multi_agent_config") or {} - self.db.add(MultiAgentConfig( - id=uuid.uuid4(), - app_id=new_app.id, + fields = dict( orchestration_mode=cfg.get("orchestration_mode", "collaboration"), master_agent_name=cfg.get("master_agent_name"), model_parameters=cfg.get("model_parameters"), @@ -291,10 +346,17 @@ class AppDslService: routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings), execution_config=cfg.get("execution_config", {}), aggregation_strategy=cfg.get("aggregation_strategy", "merge"), - is_active=True, - created_at=now, updated_at=now, - )) + ) + if create: + 
self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) + else: + existing = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app_id).first() + if existing: + for k, v in fields.items(): + setattr(existing, k, v) + else: + self.db.add(MultiAgentConfig(id=uuid.uuid4(), app_id=app_id, is_active=True, created_at=now, **fields)) elif app_type == AppType.WORKFLOW: adapter = MemoryBearAdapter(dsl) @@ -306,20 +368,39 @@ class AppDslService: for w in result.warnings: warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}") wf = dsl.get("workflow") or {} - WorkflowService(self.db).create_workflow_config( - app_id=new_app.id, - nodes=[n.model_dump() for n in result.nodes], - edges=[e.model_dump() for e in result.edges], - variables=[v.model_dump() for v in result.variables], - execution_config=wf.get("execution_config", {}), - features=wf.get("features", {}), - triggers=wf.get("triggers", []), - validate=False, - ) - - self.db.commit() - self.db.refresh(new_app) - return new_app, warnings + wf_service = WorkflowService(self.db) + if create: + wf_service.create_workflow_config( + app_id=app_id, + nodes=[n.model_dump() for n in result.nodes], + edges=[e.model_dump() for e in result.edges], + variables=[v.model_dump() for v in result.variables], + execution_config=wf.get("execution_config", {}), + features=wf.get("features", {}), + triggers=wf.get("triggers", []), + validate=False, + ) + else: + existing = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app_id).first() + if existing: + existing.nodes = [n.model_dump() for n in result.nodes] + existing.edges = [e.model_dump() for e in result.edges] + existing.variables = [v.model_dump() for v in result.variables] + existing.execution_config = wf.get("execution_config", {}) + existing.features = wf.get("features", {}) + existing.triggers = wf.get("triggers", []) + existing.updated_at = now + else: + wf_service.create_workflow_config( + 
app_id=app_id, + nodes=[n.model_dump() for n in result.nodes], + edges=[e.model_dump() for e in result.edges], + variables=[v.model_dump() for v in result.variables], + execution_config=wf.get("execution_config", {}), + features=wf.get("features", {}), + triggers=wf.get("triggers", []), + validate=False, + ) def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str: """生成唯一应用名称,同时检查本空间自有应用和共享到本空间的应用""" diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 377f9479..534ab8d0 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -13,7 +13,7 @@ import uuid from typing import Annotated, Any, Dict, List, Optional, Tuple from fastapi import Depends -from sqlalchemy import and_, delete, func, or_, select +from sqlalchemy import and_, delete, func, or_, select, update as sa_update from sqlalchemy.orm import Session from app.core.error_codes import BizCode @@ -401,7 +401,7 @@ class AppService: def _create_workflow_config( self, app_id: uuid.UUID, - data: app_schema.WorkflowConfigCreate, + data, now: datetime.datetime ): workflow_cfg = WorkflowConfig( @@ -411,6 +411,7 @@ class AppService: edges=[edge.model_dump() for edge in data.edges] if data.edges else [], variables=[var.model_dump() for var in data.variables] if data.variables else [], execution_config=data.execution_config.model_dump() if data.execution_config else {}, + features=data.features if data.features else {}, triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [], is_active=True, created_at=now, @@ -619,6 +620,28 @@ class AppService: self._validate_app_accessible(app, workspace_id) return app + def get_release_by_id(self, app_id: uuid.UUID, release_id: uuid.UUID) -> AppRelease: + """按发布版本ID获取发布快照 + + Args: + app_id: 应用ID + release_id: 发布版本ID + + Returns: + AppRelease: 发布快照 + + Raises: + BusinessException: 版本不存在或已下线 + """ + from app.repositories.app_repository import get_release_by_id + 
release = get_release_by_id(self.db, app_id, release_id) + if not release: + raise BusinessException( + f"版本 {release_id} 不存在或已下线", + BizCode.RELEASE_NOT_FOUND, + ) + return release + def create_app( self, *, @@ -678,7 +701,9 @@ class AppService: self._create_multi_agent_config(app.id, data.multi_agent_config, now) if app.type == "workflow" and data.workflow_config: - self._create_workflow_config(app.id, data.workflow_config, now) + from app.schemas.workflow_schema import WorkflowConfigCreate + wf_data = WorkflowConfigCreate(**data.workflow_config) if isinstance(data.workflow_config, dict) else data.workflow_config + self._create_workflow_config(app.id, wf_data, now) self.db.commit() self.db.refresh(app) @@ -757,6 +782,17 @@ class AppService: # 逻辑删除应用 app.is_active = False + + # 更新 app_shares 表中该应用的所有共享记录为失效状态,并更新 updated_at 时间 + stmt = sa_update(AppShare).where( + AppShare.source_app_id == app_id, + AppShare.is_active.is_(True) + ).values( + is_active=False, + updated_at=datetime.datetime.now() + ) + self.db.execute(stmt) + self.db.commit() logger.info( @@ -1347,6 +1383,7 @@ class AppService: variables=cfg.get("variables", []), execution_config=cfg.get("execution_config", {}), triggers=cfg.get("triggers", []), + features=cfg.get("features", {}), is_active=True, created_at=now, updated_at=now, diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index bd7f7496..61744ec7 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -534,6 +534,7 @@ class ConversationService: api_key = api_config.api_key api_base = api_config.api_base is_omni = api_config.is_omni + capability = api_config.capability model_type = config.type llm = RedBearLLM( @@ -542,7 +543,8 @@ class ConversationService: provider=provider, api_key=api_key, base_url=api_base, - is_omni=is_omni + is_omni=is_omni, + capability=capability, ), type=ModelType(model_type) ) diff --git 
a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 4b503f2b..81457a08 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -458,7 +458,7 @@ class AgentRunService: statement = opening["statement"] suggested_questions = opening["suggested_questions"] - + # 如果有变量,进行替换(仅支持 {{var_name}} 格式) if variables: for var_name, var_value in variables.items(): @@ -595,6 +595,10 @@ class AgentRunService: max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, + deep_thinking=effective_params.get("deep_thinking", False), + thinking_budget_tokens=effective_params.get("thinking_budget_tokens"), + json_output=effective_params.get("json_output", False), + capability=api_key_config.get("capability", []), ) # 5. 处理会话ID(创建或验证),新会话时写入开场白 @@ -637,7 +641,14 @@ class AgentRunService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") - + # 为需要运行时上下文的工具注入上下文 + for t in tools: + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) # 7. 
知识库检索 context = None @@ -689,7 +700,8 @@ class AgentRunService: "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 - }) + }), + "reasoning_content": result.get("reasoning_content") }, files=files, processed_files=processed_files, @@ -701,6 +713,7 @@ class AgentRunService: response = { "message": result["content"], + "reasoning_content": result.get("reasoning_content"), "conversation_id": conversation_id, "usage": result.get("usage", { "prompt_tokens": 0, @@ -838,7 +851,11 @@ class AgentRunService: max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, - streaming=True + streaming=True, + deep_thinking=effective_params.get("deep_thinking", False), + thinking_budget_tokens=effective_params.get("thinking_budget_tokens"), + json_output=effective_params.get("json_output", False), + capability=api_key_config.get("capability", []), ) # 5. 处理会话ID(创建或验证),新会话时写入开场白 @@ -882,7 +899,14 @@ class AgentRunService: multimodal_service = MultimodalService(self.db, model_info) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") - + # 为需要运行时上下文的工具注入上下文 + for t in tools: + if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'): + t.tool_instance.set_runtime_context( + user_id=user_id or "anonymous", + conversation_id=str(conversation_id) if conversation_id else None, + uploaded_files=processed_files or [] + ) # 7. 知识库检索 context = None @@ -898,6 +922,7 @@ class AgentRunService: # 9. 
流式调用 Agent(支持多模态),同时并行启动 TTS full_content = "" + full_reasoning = "" total_tokens = 0 # 启动流式 TTS(文本边输出边合成) @@ -916,6 +941,9 @@ class AgentRunService: ): if isinstance(chunk, int): total_tokens = chunk + elif isinstance(chunk, dict) and chunk.get("type") == "reasoning": + full_reasoning += chunk["content"] + yield self._format_sse_event("reasoning", {"content": chunk["content"]}) else: full_content += chunk yield self._format_sse_event("message", {"content": chunk}) @@ -944,7 +972,8 @@ class AgentRunService: app_id=agent_config.app_id, user_id=user_id, meta_data={ - "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}, + "reasoning_content": full_reasoning or None }, files=files, processed_files=processed_files, @@ -1272,10 +1301,30 @@ class AgentRunService: "history_files": {} } if files: + from app.models.file_metadata_model import FileMetadata + local_ids = [f.upload_file_id for f in files + if f.transfer_method.value == "local_file" and f.upload_file_id + and (not f.name or not f.size)] + meta_map = {} + if local_ids: + rows = self.db.query(FileMetadata).filter( + FileMetadata.id.in_(local_ids), + FileMetadata.status == "completed" + ).all() + meta_map = {str(r.id): r for r in rows} for f in files: + name, size = f.name, f.size + if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size): + meta = meta_map.get(str(f.upload_file_id)) + if meta: + name = name or meta.file_name + size = size or meta.file_size human_meta["files"].append({ "type": f.type, - "url": f.url + "url": f.url, + "file_type": f.file_type, + "name": name, + "size": size }) # 保存 history_files,包含 provider 和 is_omni 信息 @@ -1665,7 +1714,7 @@ class AgentRunService: """从 text_queue 取文本按句子切分后喂给 synthesizer""" import re buf = "" - sentence_end = re.compile(r'[\u3002\uff01\uff1f\.!?\n]') + sentence_end = re.compile(r'[\u3002\uff01\uff1f.!?\n]') while True: 
chunk = await text_queue.get() if chunk is None: @@ -1894,6 +1943,7 @@ class AgentRunService: "conversation_id": result['conversation_id'], "parameters_used": model_info["parameters"], "message": result.get("message"), + "reasoning_content": result.get("reasoning_content"), "usage": usage, "elapsed_time": elapsed, "tokens_per_second": ( @@ -2012,7 +2062,7 @@ class AgentRunService: # 需要从 ModelApiKey 获取实际的模型名称,或者在 ModelConfig 中添加 model 字段 return None - def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> AgentConfig: + def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> tuple[AgentConfig, Any]: """创建一个带有覆盖参数的 agent_config(浅拷贝,只修改 model_parameters) Args: @@ -2110,6 +2160,7 @@ class AgentRunService: start_time = time.time() full_content = "" + full_reasoning = "" returned_conversation_id = model_conversation_id audio_url = None audio_status = None @@ -2168,6 +2219,18 @@ class AgentRunService: "content": chunk })) + # 转发深度思考事件(带模型标识) + if event_type == "reasoning" and event_data: + reasoning_chunk = event_data.get("content", "") + full_reasoning += reasoning_chunk + await event_queue.put(self._format_sse_event("model_reasoning", { + "model_index": idx, + "model_config_id": model_config_id, + "label": model_label, + "conversation_id": returned_conversation_id, + "content": event_data.get("content", "") + })) + # 从 end 事件中提取 features 输出字段 if event_type == "end" and event_data: audio_url = event_data.get("audio_url") @@ -2199,6 +2262,7 @@ class AgentRunService: "conversation_id": returned_conversation_id, "parameters_used": model_info["parameters"], "message": full_content, + "reasoning_content": full_reasoning or None, "elapsed_time": elapsed, "audio_url": audio_url, "audio_status": audio_status, @@ -2351,6 +2415,7 @@ class AgentRunService: "label": r["label"], "conversation_id": r.get("conversation_id"), "message": r.get("message"), + "reasoning_content": r.get("reasoning_content"), "elapsed_time": 
r.get("elapsed_time", 0), "audio_url": r.get("audio_url"), "audio_status": r.get("audio_status"), diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index c226348e..9a215cd6 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -679,9 +679,9 @@ class EmotionAnalyticsService: # 查询用户的实体和标签 query = """ - MATCH (e:Entity) + MATCH (e:ExtractedEntity) WHERE e.end_user_id = $end_user_id - RETURN e.name as name, e.type as type + RETURN e.name as name, e.entity_type as type ORDER BY e.created_at DESC LIMIT 20 """ diff --git a/api/app/services/implicit_memory_service.py b/api/app/services/implicit_memory_service.py index 4bd11deb..7a186f33 100644 --- a/api/app/services/implicit_memory_service.py +++ b/api/app/services/implicit_memory_service.py @@ -34,6 +34,7 @@ from app.schemas.implicit_memory_schema import ( UserMemorySummary, ) from app.schemas.memory_config_schema import MemoryConfig +from app.services.memory_base_service import MIN_MEMORY_SUMMARY_COUNT from sqlalchemy.orm import Session logger = logging.getLogger(__name__) @@ -379,12 +380,59 @@ class ImplicitMemoryService: raise + def _build_empty_profile(self) -> dict: + """构建 MemorySummary 不足时返回的固定空白画像数据""" + now_ms = int(datetime.utcnow().timestamp() * 1000) + insufficient = "Insufficient data for analysis" + + def _empty_dimension(name: str) -> dict: + return { + "evidence": [insufficient], + "reasoning": f"No clear evidence found for {name} dimension", + "percentage": 0.0, + "dimension_name": name, + "confidence_level": 20, + } + + def _empty_category(name: str) -> dict: + return { + "evidence": [insufficient], + "percentage": 25.0, + "category_name": name, + "trending_direction": None, + } + + return { + "habits": [], + "portrait": { + "aesthetic": _empty_dimension("aesthetic"), + "creativity": _empty_dimension("creativity"), + "literature": _empty_dimension("literature"), + "technology": 
_empty_dimension("technology"), + "historical_trends": None, + "analysis_timestamp": now_ms, + "total_summaries_analyzed": 0, + }, + "preferences": [], + "interest_areas": { + "art": _empty_category("art"), + "tech": _empty_category("tech"), + "music": _empty_category("music"), + "lifestyle": _empty_category("lifestyle"), + "analysis_timestamp": now_ms, + "total_summaries_analyzed": 0, + }, + } + async def generate_complete_profile( self, user_id: str ) -> dict: """生成完整的用户画像(包含所有4个模块) + 需要该用户的 MemorySummary 节点数量 >= 5 才会真正调用 LLM 生成画像, + 否则返回固定的空白画像数据。 + Args: user_id: 用户ID @@ -394,6 +442,16 @@ class ImplicitMemoryService: logger.info(f"生成完整用户画像: user={user_id}") try: + # 前置检查:查询该用户有效的 MemorySummary 节点数量(排除孤立节点) + from app.services.memory_base_service import MemoryBaseService + base_service = MemoryBaseService() + memory_summary_count = await base_service.get_valid_memory_summary_count(user_id) + logger.info(f"用户 MemorySummary 节点数量: {memory_summary_count} (user={user_id})") + + if memory_summary_count < MIN_MEMORY_SUMMARY_COUNT: + logger.info(f"MemorySummary 数量不足 {MIN_MEMORY_SUMMARY_COUNT}(当前 {memory_summary_count}),返回空白画像: user={user_id}") + return self._build_empty_profile() + # 并行调用4个分析方法 preferences, portrait, interest_areas, habits = await asyncio.gather( self.get_preference_tags(user_id=user_id), diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index 02895d6b..bd90eee9 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -415,8 +415,11 @@ class LLMRouter: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - temperature=0.3, - max_tokens=500 + capability=api_key_config.capability, + extra_params={ + "temperature": 0.3, + "max_tokens": 500 + } ) logger.debug(f"创建 LLM 实例 - Provider: {api_key_config.provider}, Model: {api_key_config.model_name}") diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 
b0f43b51..dfb3c2da 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -393,6 +393,7 @@ class MasterAgentRouter: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, + capability=api_key_config.capability, extra_params = extra_params ) @@ -403,6 +404,17 @@ class MasterAgentRouter: response = await llm.ainvoke(prompt) ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + # 提取 token 消耗 + self._last_routing_tokens = 0 + if hasattr(response, 'usage_metadata') and response.usage_metadata: + um = response.usage_metadata + self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0) + elif hasattr(response, 'response_metadata') and response.response_metadata: + token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {}) + if isinstance(token_usage, dict): + self._last_routing_tokens = token_usage.get("total_tokens", 0) + logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}") + # 提取响应内容 if hasattr(response, 'content'): return response.content diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index c27a75be..b12bb48a 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -462,11 +462,6 @@ class MemoryAgentService: logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") - # 导入审计日志记录器 - - - - config_load_start = time.time() try: # Use a separate database session to avoid transaction failures @@ -507,10 +502,13 @@ class MemoryAgentService: async with make_read_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, - "end_user_id": end_user_id - , "storage_type": storage_type, "user_rag_memory_id": 
user_rag_memory_id, - "memory_config": memory_config} + initial_state = { + "messages": [HumanMessage(content=message)], + "search_switch": search_switch, + "end_user_id": end_user_id + , "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} # 获取节点更新信息 _intermediate_outputs = [] summary = '' @@ -522,7 +520,7 @@ class MemoryAgentService: for node_name, node_data in update_event.items(): # if 'save_neo4j' == node_name: # massages = node_data - print(f"处理节点: {node_name}") + logger.info(f"处理节点: {node_name}") # 处理不同Summary节点的返回结构 if 'Summary' in node_name: @@ -549,6 +547,11 @@ class MemoryAgentService: if retrieve_node and retrieve_node != [] and retrieve_node != {}: _intermediate_outputs.extend(retrieve_node) + # Perceptual_Retrieve 节点 + perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None) + if perceptual_node and perceptual_node != [] and perceptual_node != {}: + _intermediate_outputs.append(perceptual_node) + # Verify 节点 verify_n = node_data.get('verify', {}).get('_intermediate', None) if verify_n and verify_n != [] and verify_n != {}: diff --git a/api/app/services/memory_base_service.py b/api/app/services/memory_base_service.py index bc647752..e615af8b 100644 --- a/api/app/services/memory_base_service.py +++ b/api/app/services/memory_base_service.py @@ -265,12 +265,50 @@ async def Translation_English(modid, text, fields=None): # 其他类型(数字、布尔值、None等):原样返回 else: return text +# 隐性记忆画像生成所需的最低 MemorySummary 节点数量 +MIN_MEMORY_SUMMARY_COUNT = 5 + + class MemoryBaseService: """记忆服务基类,提供共享的辅助方法""" def __init__(self): self.neo4j_connector = Neo4jConnector() + async def get_valid_memory_summary_count( + self, + end_user_id: str + ) -> int: + """获取用户有效的 MemorySummary 节点数量(排除孤立节点)。 + + 只统计存在 DERIVED_FROM_STATEMENT 关系的 MemorySummary 节点。 + + Args: + end_user_id: 终端用户ID + + Returns: + 有效 MemorySummary 节点数量 + """ + try: + query = """ + MATCH (n:MemorySummary)-[:DERIVED_FROM_STATEMENT]->(:Statement) + 
WHERE n.end_user_id = $end_user_id + RETURN count(DISTINCT n) as count + """ + result = await self.neo4j_connector.execute_query( + query, end_user_id=end_user_id + ) + count = result[0]["count"] if result and len(result) > 0 else 0 + logger.debug( + f"有效 MemorySummary 节点数量: {count} (end_user_id={end_user_id})" + ) + return count + except Exception as e: + logger.error( + f"获取有效 MemorySummary 数量失败: {str(e)}", exc_info=True + ) + return 0 + @staticmethod def parse_timestamp(timestamp_value) -> Optional[int]: """ diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 791a6fe8..a01b1d00 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -353,15 +353,13 @@ async def get_workspace_total_memory_count( "details": [] } - # 2. 对每个 host_id 调用 search_all 获取 total + # 2. 使用 search_all_batch 批量查询所有宿主的记忆数量 from app.services import memory_storage_service - total_count = 0 - details = [] - # 如果提供了 end_user_id,只查询该用户 if end_user_id: - search_result = await memory_storage_service.search_all(end_user_id=end_user_id) + batch_result = await memory_storage_service.search_all_batch([end_user_id]) + count = batch_result.get(end_user_id, 0) # 查询用户名称 from app.repositories.end_user_repository import EndUserRepository repo = EndUserRepository(db) @@ -369,42 +367,31 @@ async def get_workspace_total_memory_count( user_name = end_user.other_name if end_user else None return { - "total_memory_count": search_result.get("total", 0), + "total_memory_count": count, "host_count": 1, "details": [{ "end_user_id": end_user_id, - "count": search_result.get("total", 0), + "count": count, "name": user_name }] } - for host in hosts: - try: - end_user_id_str = str(host.id) - - search_result = await memory_storage_service.search_all( - end_user_id=end_user_id_str - ) - - host_total = search_result.get("total", 0) - total_count += host_total - - details.append({ - "end_user_id": end_user_id_str, 
- "count": host_total, - "name": host.other_name # 使用 other_name 字段 - }) - - business_logger.debug(f"EndUser {end_user_id_str} ({host.other_name}) 记忆数: {host_total}") - - except Exception as e: - business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}") - # 失败的 host 记为 0 - details.append({ - "end_user_id": str(host.id), - "count": 0, - "name": host.other_name # 使用 other_name 字段 - }) + # 批量查询所有宿主记忆数量(一次 Neo4j 查询) + end_user_ids = [str(host.id) for host in hosts] + batch_result = await memory_storage_service.search_all_batch(end_user_ids) + + # 构建 host name 映射 + host_name_map = {str(host.id): host.other_name for host in hosts} + + total_count = sum(batch_result.values()) + details = [ + { + "end_user_id": uid, + "count": batch_result.get(uid, 0), + "name": host_name_map.get(uid) + } + for uid in end_user_ids + ] result = { "total_memory_count": total_count, @@ -519,6 +506,180 @@ def get_rag_user_kb_total_chunk( business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}") raise +def get_dashboard_yesterday_changes( + db: Session, + workspace_id: uuid.UUID, + storage_type: str, + today_data: dict +) -> dict: + """ + 计算各指标相比昨天的变化百分比。 + + - total_app_change / total_knowledge_change:只看活跃记录, + 百分比 = (截止今日活跃总量 - 截止昨日活跃总量) / 截止昨日活跃总量 + - total_memory_change / total_api_call_change: + 百分比 = (今日总量 - 昨日总量) / 昨日总量 + + 昨日总量为 0 时返回 None。返回值为浮点数,例如 0.5 表示增长 50%。 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + storage_type: 存储类型 'neo4j' | 'rag' + today_data: 当前数据,包含 total_memory, total_app, total_knowledge, total_api_call + + Returns: + { + "total_memory_change": float | None, + "total_app_change": float | None, + "total_knowledge_change": float | None, + "total_api_call_change": float | None + } + """ + from datetime import datetime + from sqlalchemy import func + from app.models.api_key_model import ApiKey, ApiKeyLog + from app.models.knowledge_model import Knowledge + from app.models.app_model import App + from app.models.appshare_model import 
AppShare + + business_logger.info(f"计算昨日对比百分比: workspace_id={workspace_id}, storage_type={storage_type}") + + now_local = datetime.now() + today_start = now_local.replace(hour=0, minute=0, second=0, microsecond=0) + + changes = { + "total_memory_change": None, + "total_app_change": None, + "total_knowledge_change": None, + "total_api_call_change": None, + } + + def _calc_percentage(today_val, yesterday_val): + """计算百分比,昨日为0时返回None""" + if yesterday_val is None or yesterday_val == 0: + return None + return round((today_val - yesterday_val) / yesterday_val, 4) + + # --- total_api_call_change: (截止今日累计总数 - 截止昨日累计总数) / 截止昨日累计总数 --- + try: + api_key_ids = [ + row[0] for row in db.query(ApiKey.id).filter( + ApiKey.workspace_id == workspace_id + ).all() + ] + if api_key_ids: + # 截止今日的累计调用总数 + total_api_until_now = db.query(func.count(ApiKeyLog.id)).filter( + ApiKeyLog.api_key_id.in_(api_key_ids), + ApiKeyLog.created_at < now_local + ).scalar() or 0 + # 截止昨日的累计调用总数(today_start 即昨日结束) + total_api_until_yesterday = db.query(func.count(ApiKeyLog.id)).filter( + ApiKeyLog.api_key_id.in_(api_key_ids), + ApiKeyLog.created_at < today_start + ).scalar() or 0 + changes["total_api_call_change"] = _calc_percentage(total_api_until_now, total_api_until_yesterday) + else: + changes["total_api_call_change"] = None + except Exception as e: + business_logger.warning(f"计算API调用昨日对比失败: {str(e)}") + + # --- total_knowledge_change: 只看活跃(status=1)且为顶层知识库(parent_id=workspace_id),百分比 = (今日活跃总量 - 昨日活跃总量) / 昨日活跃总量 --- + try: + # 截止今日的活跃知识库总量(当前 status=1,parent_id=workspace_id) + today_knowledge = db.query(func.count(Knowledge.id)).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.parent_id == Knowledge.workspace_id + ).scalar() or 0 + # 截止昨日的活跃知识库总量(昨日之前创建的、当前仍 status=1,parent_id=workspace_id) + yesterday_knowledge = db.query(func.count(Knowledge.id)).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.parent_id == 
Knowledge.workspace_id, + Knowledge.created_at < today_start + ).scalar() or 0 + + changes["total_knowledge_change"] = _calc_percentage(today_knowledge, yesterday_knowledge) + except Exception as e: + business_logger.warning(f"计算知识库昨日对比失败: {str(e)}") + + # --- total_app_change: 只看活跃(is_active=True),百分比 = (今日活跃总量 - 昨日活跃总量) / 昨日活跃总量 --- + try: + # === 自有app === + today_own_apps = db.query(func.count(App.id)).filter( + App.workspace_id == workspace_id, + App.is_active == True + ).scalar() or 0 + yesterday_own_apps = db.query(func.count(App.id)).filter( + App.workspace_id == workspace_id, + App.is_active == True, + App.created_at < today_start + ).scalar() or 0 + + # === 被分享app === + today_shared_apps = db.query(func.count(AppShare.id)).filter( + AppShare.target_workspace_id == workspace_id, + AppShare.is_active == True + ).scalar() or 0 + yesterday_shared_apps = db.query(func.count(AppShare.id)).filter( + AppShare.target_workspace_id == workspace_id, + AppShare.is_active == True, + AppShare.created_at < today_start + ).scalar() or 0 + + today_total_app = today_own_apps + today_shared_apps + yesterday_total_app = yesterday_own_apps + yesterday_shared_apps + + changes["total_app_change"] = _calc_percentage(today_total_app, yesterday_total_app) + except Exception as e: + business_logger.warning(f"计算应用数量昨日对比失败: {str(e)}") + + # --- total_memory_change: (今日总量 - 昨日总量) / 昨日总量 --- + try: + today_memory = today_data.get("total_memory") + if today_memory is None: + changes["total_memory_change"] = None + elif storage_type == "neo4j": + last_record = db.query(MemoryIncrement).filter( + MemoryIncrement.workspace_id == workspace_id, + MemoryIncrement.created_at < today_start + ).order_by(desc(MemoryIncrement.created_at)).first() + if last_record is None or last_record.total_num == 0: + changes["total_memory_change"] = None + else: + changes["total_memory_change"] = _calc_percentage(today_memory, last_record.total_num) + elif storage_type == "rag": + from app.models.document_model 
import Document + from app.models.end_user_model import EndUser as _EndUser + from app.models.app_model import App as _App + + end_user_ids = [ + str(eid) for (eid,) in db.query(_EndUser.id) + .join(_App, _EndUser.app_id == _App.id) + .filter(_App.workspace_id == workspace_id) + .all() + ] + if not end_user_ids: + changes["total_memory_change"] = None + else: + file_names = [f"{uid}.txt" for uid in end_user_ids] + yesterday_chunk = int(db.query(func.sum(Document.chunk_num)).filter( + Document.file_name.in_(file_names), + Document.created_at < today_start + ).scalar() or 0) + if yesterday_chunk == 0: + changes["total_memory_change"] = None + else: + changes["total_memory_change"] = _calc_percentage(today_memory, yesterday_chunk) + except Exception as e: + business_logger.warning(f"计算记忆总量昨日对比失败: {str(e)}") + + business_logger.info(f"昨日对比百分比计算完成: {changes}") + return changes + + def get_current_user_total_chunk( end_user_id: str, db: Session, @@ -642,7 +803,6 @@ def get_rag_content( "page": { "page": page, "pagesize": pagesize, - "total": 0, "hasnext": False, }, "items": [] @@ -736,13 +896,12 @@ def get_rag_content( "page": { "page": page, "pagesize": pagesize, - "total": global_total, "hasnext": offset_end < global_total, }, "items": conversations } - business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)} 条对话") + business_logger.info(f"成功获取RAG内容: page={page}, 返回={len(conversations)} 条对话") return result except Exception as e: @@ -881,4 +1040,60 @@ async def generate_rag_profile( "tags_count": len(tags), "personas_count": len(personas), "insight_generated": bool(insight_sections.get("memory_insight")), - } \ No newline at end of file + } + + +def get_dashboard_common_stats(db: Session, workspace_id) -> dict: + """ + 获取 dashboard 中 neo4j/rag 分支共享的统计数据: + total_app、total_knowledge、total_api_call + + Returns: + dict: {"total_app": int, "total_knowledge": int, "total_api_call": int} + """ + result = {"total_app": 0, "total_knowledge": 
0, "total_api_call": 0} + + # total_app: 统计当前空间下的所有app数量(包含自有 + 被分享给本工作空间的app) + try: + from app.services import app_service as _app_svc + _, total_app = _app_svc.AppService(db).list_apps( + workspace_id=workspace_id, include_shared=True, pagesize=1 + ) + result["total_app"] = total_app + except Exception as e: + business_logger.warning(f"获取应用数量失败: {e}") + + # total_knowledge: 统计顶层知识库(parent_id = workspace_id) + try: + from sqlalchemy import func as _func + from app.models.knowledge_model import Knowledge as _Knowledge + total_knowledge = db.query(_func.count(_Knowledge.id)).filter( + _Knowledge.workspace_id == workspace_id, + _Knowledge.status == 1, + _Knowledge.parent_id == _Knowledge.workspace_id + ).scalar() or 0 + result["total_knowledge"] = total_knowledge + except Exception as e: + business_logger.warning(f"获取知识库数量失败: {e}") + + # total_api_call: 截止当前的历史累计调用总数 + try: + from sqlalchemy import func as _api_func + from app.models.api_key_model import ApiKey as _ApiKey, ApiKeyLog as _ApiKeyLog + + _api_key_ids = [ + row[0] for row in db.query(_ApiKey.id).filter( + _ApiKey.workspace_id == workspace_id + ).all() + ] + if _api_key_ids: + total_api_calls = db.query(_api_func.count(_ApiKeyLog.id)).filter( + _ApiKeyLog.api_key_id.in_(_api_key_ids) + ).scalar() or 0 + else: + total_api_calls = 0 + result["total_api_call"] = total_api_calls + except Exception as e: + business_logger.warning(f"获取API调用统计失败: {e}") + + return result diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 7cf94a1a..8fa9c9bf 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -232,7 +232,8 @@ class MemoryPerceptualService: provider=model_config.provider, api_key=model_config.api_key, base_url=model_config.api_base, - is_omni=model_config.is_omni + is_omni=model_config.is_omni, + capability=model_config.capability, ) ) return llm, model_config diff --git 
a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index b3a66734..132370b6 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -613,37 +613,6 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: return data -async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]: - result = await _neo4j_connector.execute_query( - MemoryConfigRepository.SEARCH_FOR_ALL, - end_user_id=end_user_id, - ) - - # 检查结果是否为空或长度不足 - if not result or len(result) < 4: - data = { - "total": 0, - "counts": { - "dialogue": 0, - "chunk": 0, - "statement": 0, - "entity": 0, - }, - } - return data - - data = { - "total": result[-1]["Count"], - "counts": { - "dialogue": result[0]["Count"], - "chunk": result[1]["Count"], - "statement": result[2]["Count"], - "entity": result[3]["Count"], - }, - } - return data - - async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]: """统一知识库类型分布接口。 diff --git a/api/app/services/model_parameter_merger.py b/api/app/services/model_parameter_merger.py index 262e3d49..6911a9d5 100644 --- a/api/app/services/model_parameter_merger.py +++ b/api/app/services/model_parameter_merger.py @@ -45,12 +45,21 @@ class ModelParameterMerger: "frequency_penalty": 0.0, "presence_penalty": 0.0, "n": 1, - "stop": None + "stop": None, + "deep_thinking": False, + "thinking_budget_tokens": None, + "json_output": False } # 合并参数:默认值 -> 模型配置 -> Agent 配置 merged = default_params.copy() + # Pydantic 对象转为 dict + if model_config_params and hasattr(model_config_params, 'model_dump'): + model_config_params = model_config_params.model_dump() + if agent_config_params and hasattr(agent_config_params, 'model_dump'): + agent_config_params = agent_config_params.model_dump() + # 应用模型配置参数 if model_config_params: for key in default_params: diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index c9266667..8807020b 
100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -85,15 +85,16 @@ class ModelConfigService: @staticmethod async def validate_model_config( - db: Session, - *, - model_name: str, - provider: str, - api_key: str, - api_base: Optional[str] = None, - model_type: str = "llm", - test_message: str = "Hello", - is_omni: bool = False + db: Session, + *, + model_name: str, + provider: str, + api_key: str, + api_base: Optional[str] = None, + model_type: str = "llm", + test_message: str = "Hello", + is_omni: bool = False, + capability: Optional[list] = None ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -124,8 +125,11 @@ class ModelConfigService: api_key=api_key, base_url=api_base, is_omni=is_omni, - temperature=0.7, - max_tokens=100 + capability=capability, + extra_params={ + "temperature": 0.7, + "max_tokens": 100 + } ) # 根据模型类型选择不同的验证方式 @@ -320,7 +324,8 @@ class ModelConfigService: api_base=api_key_data.api_base, model_type=model_data.type, test_message="Hello", - is_omni=model_data.is_omni + is_omni=model_data.is_omni, + capability=model_data.capability ) if not validation_result["valid"]: raise BusinessException( @@ -590,7 +595,8 @@ class ModelApiKeyService: api_base=data.api_base, model_type=model_config.type, test_message="Hello", - is_omni=data.is_omni + is_omni=data.is_omni, + capability=model_config.capability ) if not validation_result["valid"]: # 记录验证失败的模型,但不抛出异常 @@ -675,7 +681,8 @@ class ModelApiKeyService: api_base=api_key_data.api_base, model_type=model_config.type, test_message="Hello", - is_omni=api_key_data.is_omni + is_omni=api_key_data.is_omni, + capability=model_config.capability ) if not validation_result["valid"]: raise BusinessException( @@ -707,7 +714,8 @@ class ModelApiKeyService: api_base=api_key_data.api_base or existing_api_key.api_base, model_type=model_config.type, test_message="Hello", - is_omni=model_config.is_omni + is_omni=model_config.is_omni, + capability=model_config.capability ) if not 
validation_result["valid"]: raise BusinessException( @@ -723,10 +731,21 @@ class ModelApiKeyService: @staticmethod def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool: """删除API Key""" - if not ModelApiKeyRepository.get_by_id(db, api_key_id): + api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) + if not api_key: raise BusinessException("API Key不存在", BizCode.NOT_FOUND) + model_config_ids = [mc.id for mc in api_key.model_configs] + success = ModelApiKeyRepository.delete(db, api_key_id) + + for model_config_id in model_config_ids: + model_config = ModelConfigRepository.get_by_id(db, model_config_id) + if model_config: + has_active_key = any(key.is_active for key in model_config.api_keys) + if not has_active_key and model_config.is_active: + model_config.is_active = False + db.commit() return success diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 60a3b5b8..d30dc822 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -287,6 +287,11 @@ class MultiAgentOrchestrator: sub_conversation_id = None total_tokens = 0 + # 累加 Master Agent 路由决策消耗的 token + total_tokens += task_analysis.get("routing_tokens", 0) + # 累加 Master Agent 整合消耗的 token + total_tokens += getattr(self, '_last_merge_tokens', 0) + if isinstance(results, dict): sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") # 提取 token 信息 @@ -358,12 +363,16 @@ class MultiAgentOrchestrator: variables=variables ) + # 获取路由决策消耗的 token + routing_tokens = getattr(self.router, '_last_routing_tokens', 0) + logger.info( "Master Agent 分析完成", extra={ "selected_agent": routing_decision.get("selected_agent_id"), "confidence": routing_decision.get("confidence"), - "strategy": routing_decision.get("strategy") + "strategy": routing_decision.get("strategy"), + "routing_tokens": routing_tokens } ) @@ -372,7 +381,8 @@ class MultiAgentOrchestrator: 
"variables": variables or {}, "sub_agents": self.config.sub_agents, "initial_context": variables or {}, - "routing_decision": routing_decision + "routing_decision": routing_decision, + "routing_tokens": routing_tokens } async def _execute_sequential( @@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator: # 5. 流式执行子 Agent sub_conversation_id = None + # Master Agent 路由决策消耗的 token,通过 sub_usage 事件发送给上层 + routing_tokens = task_analysis.get("routing_tokens", 0) + if routing_tokens > 0: + yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens}) + async for event in self._execute_sub_agent_stream( agent_data["config"], message, @@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator: except: pass + # 直接透传所有事件(包括 sub_usage),累加统一由上层处理 yield event # 6. 如果有会话 ID,发送一个包含它的事件 @@ -2600,8 +2616,11 @@ class MultiAgentOrchestrator: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - temperature=0.7, # 整合任务使用中等温度 - max_tokens=2000 + capability=api_key_config.capability, + extra_params={ + "temperature": 0.7, # 整合任务使用中等温度 + "max_tokens": 2000 + } ) # 创建 LLM 实例 @@ -2612,6 +2631,17 @@ class MultiAgentOrchestrator: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + # 提取整合消耗的 token + merge_tokens = 0 + if hasattr(response, 'usage_metadata') and response.usage_metadata: + um = response.usage_metadata + merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0) + elif hasattr(response, 'response_metadata') and response.response_metadata: + token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {}) + if isinstance(token_usage, dict): + merge_tokens = token_usage.get("total_tokens", 0) + self._last_merge_tokens = merge_tokens + # 提取响应内容 if hasattr(response, 'content'): merged_response = response.content @@ -2621,7 +2651,8 @@ class MultiAgentOrchestrator: logger.info( "Master Agent 整合完成", extra={ - "merged_length": 
len(merged_response) + "merged_length": len(merged_response), + "merge_tokens": merge_tokens } ) @@ -2766,9 +2797,12 @@ class MultiAgentOrchestrator: api_key=api_key_config.api_key, base_url=api_key_config.api_base, is_omni=api_key_config.is_omni, - temperature=0.7, - max_tokens=2000, - extra_params={"streaming": True} # 启用流式输出 + capability=api_key_config.capability, + extra_params={ + "temperature": 0.7, + "max_tokens": 2000, + "streaming": True # 启用流式输出 + } ) # 创建 LLM 实例 diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 184220a8..1686a164 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -185,7 +185,8 @@ class PromptOptimizerService: provider=api_config.provider, api_key=api_config.api_key, base_url=api_config.api_base, - is_omni=api_config.is_omni + is_omni=api_config.is_omni, + capability=api_config.capability, ), type=ModelType(model_config.type)) try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') @@ -226,10 +227,20 @@ class PromptOptimizerService: content = getattr(chunk, "content", chunk) if not content: continue - buffer += content + if isinstance(content, str): + buffer += content + elif isinstance(content, list): + for _ in content: + buffer += _["text"] + else: + logger.error(f"Unsupported content type - {content}") + raise Exception("Unsupported content type") cache = buffer[:-20] + last_idx = 19 + while cache and cache[-1] == '\\' and last_idx > 0: + cache = buffer[:-last_idx] + last_idx -= 1 - # 尝试找到 "prompt": " 开始位置 if prompt_finished: continue @@ -271,7 +282,7 @@ class PromptOptimizerService: def parser_prompt_variables(prompt: str): try: pattern = r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}' - matches = re.findall(pattern, prompt) + matches = re.findall(pattern, str(prompt)) variables = list(set(matches)) return variables except Exception as e: diff --git a/api/app/services/shared_chat_service.py 
b/api/app/services/shared_chat_service.py index c74604a5..37956d77 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -248,7 +248,10 @@ class SharedChatService: max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, - + deep_thinking=model_parameters.get("deep_thinking", False), + thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), + json_output=model_parameters.get("json_output", False), + capability=api_key_obj.capability, ) # 加载历史消息 @@ -450,7 +453,11 @@ class SharedChatService: max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, - streaming=True + streaming=True, + deep_thinking=model_parameters.get("deep_thinking", False), + thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"), + json_output=model_parameters.get("json_output", False), + capability=api_key_obj.capability or [], ) # 加载历史消息 @@ -479,6 +486,8 @@ class SharedChatService: ): if isinstance(chunk, int): total_tokens = chunk + elif isinstance(chunk, dict) and chunk.get("type") == "reasoning": + yield f"event: reasoning\ndata: {json.dumps({'content': chunk['content']}, ensure_ascii=False)}\n\n" else: full_content += chunk # 发送消息块事件 diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 089f0ec5..9a59cd81 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -34,7 +34,8 @@ BUILTIN_TOOLS = { "JsonTool": "app.core.tools.builtin.json_tool", "BaiduSearchTool": "app.core.tools.builtin.baidu_search_tool", "MinerUTool": "app.core.tools.builtin.mineru_tool", - "TextInTool": "app.core.tools.builtin.textin_tool" + "TextInTool": "app.core.tools.builtin.textin_tool", + "OpenClawTool": "app.core.tools.builtin.openclaw_tool", } @@ -340,18 +341,18 @@ class ToolService: return {"success": False, "message": f"测试失败: {str(e)}"} def ensure_builtin_tools_initialized(self, tenant_id: uuid.UUID): 
- """确保内置工具已初始化""" - existing = self.tool_repo.exists_builtin_for_tenant(self.db, tenant_id) - - if existing: + """确保内置工具已初始化(支持增量补充新工具)""" + builtin_config = self._load_builtin_config() + if not builtin_config: return - # 从配置文件加载内置工具定义 - builtin_config = self._load_builtin_config() + existing_classes = self.builtin_repo.get_existing_tool_classes(self.db, tenant_id) + added = False for tool_key, tool_info in builtin_config.items(): + if tool_info['tool_class'] in existing_classes: + continue try: - # 创建工具配置 initial_status = self._determine_initial_status(tool_info) tool_config = ToolConfig( name=tool_info['name'], @@ -367,7 +368,6 @@ class ToolService: self.db.add(tool_config) self.db.flush() - # 创建内置工具配置 builtin_config_obj = BuiltinToolConfig( id=tool_config.id, tool_class=tool_info['tool_class'], @@ -375,12 +375,14 @@ class ToolService: requires_config=tool_info.get('requires_config', False) ) self.db.add(builtin_config_obj) + added = True except Exception as e: logger.error(f"初始化内置工具失败: {tool_key}, {e}") - self.db.commit() - logger.info(f"租户 {tenant_id} 内置工具初始化完成") + if added: + self.db.commit() + logger.info(f"租户 {tenant_id} 内置工具增量初始化完成") async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]: """获取工具的所有方法 @@ -458,6 +460,9 @@ class ToolService: # 对于json_tool,根据操作类型返回相关参数 elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool': return self._get_json_tool_params(operation) + # 对于openclaw_tool,根据操作类型返回不同描述的参数 + elif hasattr(tool_instance, 'name') and tool_instance.name == 'openclaw_tool': + return self._get_openclaw_tool_params(operation) # 其他工具的默认处理:返回除operation外的所有参数 return [{ @@ -574,6 +579,29 @@ class ToolService: "default": "Asia/Shanghai" } ] + elif operation == "datetime_to_timestamp": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串,如:2026-04-07 10:30:25)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": 
"输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] else: # 默认返回所有参数(除了operation) return [ @@ -687,6 +715,65 @@ class ToolService: return base_params + @staticmethod + def _get_openclaw_tool_params(operation: str) -> List[Dict[str, Any]]: + """获取 openclaw_tool 特定操作的参数""" + if operation == "print_task": + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw,禁止改写、补充或润色用户的原文", + "required": True + }, + { + "name": "image_url", + "type": "string", + "description": "可选,附带的设计图片或参考图,OpenClaw 可据此生成 3D 模型", + "required": False + } + ] + elif operation == "device_query": + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw 的设备查询指令", + "required": True + } + ] + elif operation == "image_understand": + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)", + "required": True + }, + { + "name": "image_url", + "type": "string", + "description": "要分析的图片 URL 或 base64 data URI", + "required": False + } + ] + else: + # general 及其他 + return [ + { + "name": "message", + "type": "string", + "description": "发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求", + "required": True + }, + { + "name": "image_url", + "type": "string", + "description": "可选,附带的图片 URL 或 base64 data URI", + "required": False + } + ] + async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: """获取自定义工具的方法""" custom_config = self.custom_repo.find_by_tool_id(self.db, config.id) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index ab51d922..fdc27115 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -14,6 +14,7 @@ from pydantic import BaseModel, Field 
from sqlalchemy.orm import Session from app.core.logging_config import get_logger +from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.conversation_repository import ConversationRepository @@ -21,7 +22,7 @@ from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping -from app.services.memory_base_service import MemoryBaseService +from app.services.memory_base_service import MemoryBaseService, MIN_MEMORY_SUMMARY_COUNT from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService @@ -398,12 +399,16 @@ class UserMemoryService: } # 构建响应数据(转换时间为毫秒时间戳) + # 将 meta_data 中的 profile、knowledge_tags、behavioral_hints 平铺到顶层 + meta = end_user_info_record.meta_data or {} response_data = { "end_user_info_id": str(end_user_info_record.id), "end_user_id": str(end_user_info_record.end_user_id), "other_name": end_user_info_record.other_name, "aliases": end_user_info_record.aliases, - "meta_data": end_user_info_record.meta_data, + "profile": meta.get("profile"), + "knowledge_tags": meta.get("knowledge_tags"), + "behavioral_hints": meta.get("behavioral_hints"), "created_at": datetime_to_timestamp(end_user_info_record.created_at), "updated_at": datetime_to_timestamp(end_user_info_record.updated_at) } @@ -473,7 +478,7 @@ class UserMemoryService: allowed_fields = {'other_name', 'aliases', 'meta_data'} # 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中 - _user_placeholder_names = {'用户', '我', 'User', 'I'} + _user_placeholder_names = 
_USER_PLACEHOLDER_NAMES # 过滤 other_name:不允许设置为占位名称 if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names: @@ -1500,7 +1505,7 @@ async def analytics_memory_types( 2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取) 3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量 4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取) - 5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一 + 5. 隐性记忆 (IMPLICIT_MEMORY) = MemorySummary 节点数量(需 >= MIN_MEMORY_SUMMARY_COUNT 才显示,否则为 0) 6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取) 7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取) 8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取) @@ -1557,23 +1562,15 @@ async def analytics_memory_types( logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}") work_count = 0 - # 获取隐性记忆数量(基于 Statement 节点数量的三分之一) + # 获取隐性记忆数量(基于有关联关系的 MemorySummary 节点数量,需 >= MIN_MEMORY_SUMMARY_COUNT 才计入) implicit_count = 0 if end_user_id: try: - # 查询 Statement 节点数量 - query = """ - MATCH (n:Statement) - WHERE n.end_user_id = $end_user_id - RETURN count(n) as count - """ - result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id) - statement_count = result[0]["count"] if result and len(result) > 0 else 0 - # 取三分之一作为隐性记忆数量 - implicit_count = round(statement_count / 3) - logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})") + memory_summary_count = await base_service.get_valid_memory_summary_count(end_user_id) + implicit_count = memory_summary_count if memory_summary_count >= MIN_MEMORY_SUMMARY_COUNT else 0 + logger.debug(f"隐性记忆数量(有效MemorySummary节点数): {implicit_count} (有效MemorySummary总数={memory_summary_count}, end_user_id={end_user_id})") except Exception as e: - 
logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}") + logger.warning(f"获取MemorySummary数量失败,隐性记忆数量设为0: {str(e)}") implicit_count = 0 # 原有的基于行为习惯的统计方式(已注释) @@ -1639,7 +1636,7 @@ async def analytics_memory_types( "WORKING_MEMORY": work_count, # 工作记忆(基于会话数量) "SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量) "EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆) - "IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3) + "IMPLICIT_MEMORY": implicit_count, # 隐性记忆(MemorySummary节点数,需>=MIN_MEMORY_SUMMARY_COUNT) "EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计) "EPISODIC_MEMORY": episodic_count, # 情景记忆 "FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值) diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index 3122d282..43a58c5f 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -285,7 +285,7 @@ def activate_user(db: Session, user_id_to_activate: uuid.UUID, current_user: Use try: # 查找用户 business_logger.debug(f"查找待激活用户: {user_id_to_activate}") - db_user = user_repository.get_user_by_id(db, user_id=user_id_to_activate) + db_user = user_repository.get_user_by_id_regardless_active(db, user_id=user_id_to_activate) if not db_user: business_logger.warning(f"用户不存在: {user_id_to_activate}") raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py index fd8f25f3..5a766a72 100644 --- a/api/app/services/workflow_import_service.py +++ b/api/app/services/workflow_import_service.py @@ -69,6 +69,7 @@ class WorkflowImportService: edges=workflow_config.edges, nodes=workflow_config.nodes, variables=workflow_config.variables, + features=workflow_config.features, warnings=workflow_config.warnings, errors=workflow_config.errors ) @@ -95,7 +96,8 @@ class WorkflowImportService: workflow_config=WorkflowConfigCreate( nodes=config["nodes"], edges=config["edges"], - variables=config["variables"] + 
variables=config["variables"], + features=config.get("features", {}) ) ) ) diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 13267078..0d282d78 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -16,7 +16,6 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config -from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution @@ -26,7 +25,7 @@ from app.repositories.workflow_repository import ( WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) -from app.schemas import DraftRunRequest, FileInput +from app.schemas import DraftRunRequest, FileInput, FileType from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService @@ -453,22 +452,70 @@ class WorkflowService: "success_rate": completed / total if total > 0 else 0 } + async def _resolve_variables_file_defaults( + self, + variables: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Convert FileInput-format defaults in workflow variables to full FileObject dicts.""" + from app.core.workflow.utils.file_processor import ( + resolve_local_file_object_dict, + fetch_remote_file_meta, + ) + + async def _resolve_one(item: dict) -> dict | None: + if not isinstance(item, dict) or item.get("is_file"): + return item + transfer_method = item.get("transfer_method", "remote_url") + file_type = FileType.trans(item.get("type", "document")) + origin_file_type = item.get("file_type") or file_type + if transfer_method == "remote_url": + url = item.get("url", "") + 
return await fetch_remote_file_meta(url, file_type, origin_file_type) if url else None + else: + return resolve_local_file_object_dict(self.db, item.get("upload_file_id"), file_type, origin_file_type) + + result = [] + for var_def in variables: + var_type = var_def.get("type", "") + default = var_def.get("default") + if var_type == "file" and isinstance(default, dict) and not default.get("is_file"): + var_def = {**var_def, "default": await _resolve_one(default)} + elif var_type == "array[file]" and isinstance(default, list): + resolved = [] + for item in default: + r = await _resolve_one(item) + if r is not None: + resolved.append(r) + var_def = {**var_def, "default": resolved} + result.append(var_def) + return result + async def _handle_file_input(self, files: list[FileInput]): if not files: return [] + from app.core.workflow.utils.file_processor import ( + resolve_local_file_object_dict, + build_file_object_dict_from_meta, + fetch_remote_file_meta, + ) + files_struct = [] for file in files: - files_struct.append( - FileObject( - type=file.type, - url=await self.multimodal_service.get_file_url(file), - transfer_method=file.transfer_method, - file_id=str(file.upload_file_id) if file.upload_file_id else None, - origin_file_type=file.file_type, - is_file=True - ).model_dump() - ) + url = await self.multimodal_service.get_file_url(file) + file_type = str(file.type) + origin_file_type = file.file_type or file_type + + if file.transfer_method.value == "local_file" and file.upload_file_id: + fo = resolve_local_file_object_dict(self.db, file.upload_file_id, file_type, origin_file_type) + files_struct.append(fo or build_file_object_dict_from_meta( + file_type=file_type, transfer_method="local_file", + origin_file_type=origin_file_type, + file_id=str(file.upload_file_id), url=url, + file_name=None, file_size=None, file_ext=None, content_type=None, + )) + else: + files_struct.append(await fetch_remote_file_meta(url, file_type, origin_file_type)) return files_struct 
@staticmethod @@ -545,6 +592,12 @@ class WorkflowService: def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]: storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id) user_rag_memory_id = "" + # 如果 storage_type 为 None,使用默认值 'neo4j' + if not storage_type: + storage_type = 'neo4j' + logger.warning( + f"Storage type not set for workspace {workspace_id}, using default: neo4j" + ) if storage_type == "rag": knowledge = knowledge_repository.get_knowledge_by_name( db=self.db, @@ -659,6 +712,26 @@ class WorkflowService: input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) + # 新会话时写入开场白 + is_new_conversation = init_message_length == 0 + if is_new_conversation: + opening_cfg = feature_configs.get("opening_statement", {}) + if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"): + statement = opening_cfg["statement"] + suggested_questions = opening_cfg.get("suggested_questions", []) + if payload.variables: + for var_name, var_value in payload.variables.items(): + statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value)) + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="assistant", + content=statement, + meta_data={"suggested_questions": suggested_questions} + ) + # 注入到 conv_messages,让 LLM 感知开场白 + input_data["conv_messages"] = [{"role": "assistant", "content": statement}] + init_message_length = 1 + result = await execute_workflow( workflow_config=workflow_config_dict, input_data=input_data, @@ -696,12 +769,21 @@ class WorkflowService: content=human_message, meta_data=human_meta ) + # 过滤 citations + citations = result.get("citations", []) + citation_cfg = feature_configs.get("citation", {}) + filtered_citations = ( + citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else [] + ) + assistant_meta = {"usage": token_usage, "audio_url": None} + if filtered_citations: 
+ assistant_meta["citations"] = filtered_citations self.conversation_service.add_message( message_id=message_id, conversation_id=conversation_id_uuid, role="assistant", content=assistant_message, - meta_data={"usage": token_usage, "audio_url": None} + meta_data=assistant_meta ) self.update_execution_status( execution.execution_id, @@ -720,6 +802,7 @@ class WorkflowService: ) logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id}," f" error: {result.get('error')}") + filtered_citations = [] # 返回增强的响应结构 return { @@ -734,7 +817,8 @@ class WorkflowService: "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID "error_message": result.get("error"), "elapsed_time": result.get("elapsed_time"), - "token_usage": result.get("token_usage") + "token_usage": result.get("token_usage"), + "citations": filtered_citations, } except Exception as e: @@ -825,6 +909,27 @@ class WorkflowService: input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) message_id = uuid.uuid4() + + # 新会话时写入开场白 + is_new_conversation = init_message_length == 0 + if is_new_conversation: + opening_cfg = feature_configs.get("opening_statement", {}) + if isinstance(opening_cfg, dict) and opening_cfg.get("enabled") and opening_cfg.get("statement"): + statement = opening_cfg["statement"] + suggested_questions = opening_cfg.get("suggested_questions", []) + if payload.variables: + for var_name, var_value in payload.variables.items(): + statement = statement.replace(f"{{{{{var_name}}}}}", str(var_value)) + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="assistant", + content=statement, + meta_data={"suggested_questions": suggested_questions} + ) + # 注入到 conv_messages,让 LLM 感知开场白 + input_data["conv_messages"] = [{"role": "assistant", "content": statement}] + init_message_length = 1 + async for event in execute_workflow_stream( workflow_config=workflow_config_dict, 
input_data=input_data, @@ -852,7 +957,10 @@ class WorkflowService: for file in message["content"]: human_meta["files"].append({ "type": file.get("type"), - "url": file.get("url") + "url": file.get("url"), + "file_type": file.get("origin_file_type"), + "name": file.get("name"), + "size": file.get("size") }) if message["role"] == "assistant": assistant_message = message["content"] @@ -862,12 +970,21 @@ class WorkflowService: content=human_message, meta_data=human_meta ) + # 过滤 citations + citations = event.get("data", {}).get("citations", []) + citation_cfg = feature_configs.get("citation", {}) + filtered_citations = ( + citations if isinstance(citation_cfg, dict) and citation_cfg.get("enabled") else [] + ) + assistant_meta = {"usage": token_usage, "audio_url": None} + if filtered_citations: + assistant_meta["citations"] = filtered_citations self.conversation_service.add_message( message_id=message_id, conversation_id=conversation_id_uuid, role="assistant", content=assistant_message, - meta_data={"usage": token_usage, "audio_url": None} + meta_data=assistant_meta ) self.update_execution_status( execution.execution_id, @@ -875,6 +992,7 @@ class WorkflowService: output_data=event.get("data"), token_usage=token_usage.get("total_tokens", None) ) + event.setdefault("data", {})["citations"] = filtered_citations logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") elif status == "failed": diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 90b5cf65..4034eb6d 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -480,21 +480,21 @@ def create_workspace_invite( try: # 检查权限 _check_workspace_admin_permission(db, workspace_id, user) - if settings.ENABLE_SINGLE_WORKSPACE: - # 检查被邀请用户是否已经在工作空间中 - from app.repositories import user_repository - invited_user = user_repository.get_user_by_email(db, invite_data.email) + # if 
settings.ENABLE_SINGLE_WORKSPACE: + # 检查被邀请用户是否已经在工作空间中 + from app.repositories import user_repository + invited_user = user_repository.get_user_by_email(db, invite_data.email) - if invited_user: - # 用户存在,检查是否已经是工作空间成员 - existing_member = workspace_repository.get_member_in_workspace( - db=db, - user_id=invited_user.id, - workspace_id=workspace_id - ) - if existing_member: - business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员") - raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS) + if invited_user: + # 用户存在,检查是否已经是工作空间成员 + existing_member = workspace_repository.get_member_in_workspace( + db=db, + user_id=invited_user.id, + workspace_id=workspace_id + ) + if existing_member: + business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员") + raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS) # 检查是否已有待处理的邀请 invite_repo = WorkspaceInviteRepository(db) diff --git a/api/app/tasks.py b/api/app/tasks.py index 72421a5f..5a71066a 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,4 +1,5 @@ import asyncio +import json import os import re import shutil @@ -44,6 +45,23 @@ from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) +# ── 预编译文件类型正则 & 常量 ────────────────────────────────── +AUDIO_PATTERN = re.compile( + r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", + re.IGNORECASE, +) +VIDEO_IMAGE_PATTERN = re.compile( + r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", + re.IGNORECASE, +) +DEFAULT_PARSE_LANGUAGE = "Chinese" +DEFAULT_PARSE_TO_PAGE = 100_000 +EMBEDDING_BATCH_SIZE = int(os.getenv("EMBEDDING_BATCH_SIZE", "20")) +# Embedding 并发写入的最大线程数,需根据模型 API rate limit 调整 +EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", "3")) +# auto_questions LLM 并发调用的最大线程数 +AUTO_QUESTIONS_MAX_WORKERS = int(os.getenv("AUTO_QUESTIONS_MAX_WORKERS", "5")) + # 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 连接 CELERY_BACKEND DB,与 write_message:last_done 
时间戳写入保持一致 # 使用连接池而非单例客户端,提供更好的并发性能和自动重连 @@ -61,9 +79,9 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool | None: db=settings.REDIS_DB_CELERY_BACKEND, password=settings.REDIS_PASSWORD, decode_responses=True, - max_connections=10, + max_connections=100, socket_connect_timeout=5, - socket_timeout=5, + socket_timeout=10, retry_on_timeout=True, health_check_interval=30, ) @@ -160,28 +178,67 @@ def process_item(item: dict): return result +def _build_vision_model(file_path: str, db_knowledge): + """根据文件类型选择合适的视觉/音频模型,避免冗余初始化。""" + if AUDIO_PATTERN.search(file_path): + omni_key = os.getenv("QWEN3_OMNI_API_KEY", "") + omni_model = os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash") + omni_base = os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + return QWenSeq2txt( + key=omni_key, + model_name=omni_model, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=omni_base, + ) + if VIDEO_IMAGE_PATTERN.search(file_path): + omni_key = os.getenv("QWEN3_OMNI_API_KEY", "") + omni_model = os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash") + omni_base = os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + return QWenCV( + key=omni_key, + model_name=omni_model, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=omni_base, + ) + # 默认:使用知识库配置的 image2text 模型 + return QWenCV( + key=db_knowledge.image2text.api_keys[0].api_key, + model_name=db_knowledge.image2text.api_keys[0].model_name, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=db_knowledge.image2text.api_keys[0].api_base, + ) + + @celery_app.task(name="app.core.rag.tasks.parse_document") def parse_document(file_path: str, document_id: uuid.UUID): """ Document parsing, vectorization, and storage """ - # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import importlib - import trio - importlib.reload(trio) - db = next(get_db()) # Manually call the generator db_document = None - db_knowledge = None - progress_msg = 
f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" - try: + progress_lines: list[str] = [f"{datetime.now().strftime('%H:%M:%S')} Task has been received."] + + def _progress_msg() -> str: + return "\n".join(progress_lines) + "\n" + + with get_db_context() as db: + try: + # Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确 + if not isinstance(document_id, uuid.UUID): + document_id = uuid.UUID(str(document_id)) + db_document = db.query(Document).filter(Document.id == document_id).first() + if db_document is None: + raise ValueError(f"Document {document_id} not found") db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() + if db_knowledge is None: + raise ValueError(f"Knowledge {db_document.kb_id} not found") + # 1. Document parsing & segmentation - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.") start_time = time.time() db_document.progress = 0.0 - db_document.progress_msg = progress_msg + db_document.progress_msg = _progress_msg() db_document.process_begin_at = datetime.now(tz=timezone.utc) db_document.process_duration = 0.0 db_document.run = 1 @@ -189,220 +246,195 @@ def parse_document(file_path: str, document_id: uuid.UUID): db.refresh(db_document) def progress_callback(prog=None, msg=None): - nonlocal progress_msg # Declare the use of an external progress_msg variable - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.") - # Prepare to configure chat_mdl、embedding_model、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - 
model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - if re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", file_path, - re.IGNORECASE): - vision_model = QWenSeq2txt( - key=os.getenv("QWEN3_OMNI_API_KEY", ""), - model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"), - lang="Chinese", - base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"), - ) - elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", file_path, - re.IGNORECASE): - vision_model = QWenCV( - key=os.getenv("QWEN3_OMNI_API_KEY", ""), - model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"), - lang="Chinese", - base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"), - ) - else: - print(file_path) + # Prepare vision_model for parsing + vision_model = _build_vision_model(file_path, db_knowledge) from app.core.rag.app.naive import chunk res = chunk(filename=file_path, from_page=0, - to_page=100000, + to_page=DEFAULT_PARSE_TO_PAGE, callback=progress_callback, vision_model=vision_model, parser_config=db_document.parser_config, is_root=False) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.") db_document.progress = 0.8 - db_document.progress_msg = progress_msg + db_document.progress_msg = _progress_msg() db.commit() db.refresh(db_document) # 2. 
Document vectorization and storage total_chunks = len(res) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" - batch_size = 100 - total_batches = ceil(total_chunks / batch_size) - progress_per_batch = 0.2 / total_batches # Progress of each batch - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2.1 Delete document vector index - vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) - # 2.2 Vectorize and import batch documents - for batch_start in range(0, total_chunks, batch_size): - batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds - batch = res[batch_start: batch_end] # Retrieve the current batch - chunks = [] + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.") - # Process the current batch - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch # Calculate global index - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - if db_document.parser_config.get("auto_questions", 0): - topn = db_document.parser_config["auto_questions"] - cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", - {"topn": topn}) + if total_chunks == 0: + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.") + else: + total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE) + progress_per_batch = 0.2 / total_batches + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + # 2.1 Delete document vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) + # 2.2 
Vectorize and import batch documents + auto_questions_topn = db_document.parser_config.get("auto_questions", 0) + chat_model = None + if auto_questions_topn: + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + + # 预先构建所有 batch 的 chunks,保证 sort_id 全局有序 + all_batch_chunks: list[list[DocumentChunk]] = [] + + if auto_questions_topn: + # auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组 + # 构建 (global_idx, item) 列表 + indexed_items = list(enumerate(res)) + + def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]: + """为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)""" + global_idx, item = idx_item + content = item["content_with_weight"] + cached = get_llm_cache(chat_model.model_name, content, "question", + {"topn": auto_questions_topn}) if not cached: - cached = question_proposal(chat_model, item["content_with_weight"], topn) - set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", - {"topn": topn}) - chunks.append( - DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", - metadata=metadata)) - else: - chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + cached = question_proposal(chat_model, content, auto_questions_topn) + set_llm_cache(chat_model.model_name, content, cached, "question", + {"topn": auto_questions_topn}) + return global_idx, cached - # Bulk segmented vector import - vector_service.add_chunks(chunks) + # 并发调用 LLM 生成问题 + question_map: dict[int, str] = {} + with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor: + futures = {q_executor.submit(_generate_question, item): item[0] + for item in indexed_items} + for future in futures: + global_idx, cached = future.result() + question_map[global_idx] = cached - # Update progress - db_document.progress += progress_per_batch - progress_msg += 
f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" - db_document.progress_msg = progress_msg + progress_lines.append( + f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks " + f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).") + + # 按 batch 分组组装 DocumentChunk + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + chunks = [] + for global_idx in range(batch_start, batch_end): + item = res[global_idx] + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + cached = question_map[global_idx] + chunks.append( + DocumentChunk( + page_content=f"question: {cached} answer: {item['content_with_weight']}", + metadata=metadata)) + all_batch_chunks.append(chunks) + else: + # 无 auto_questions:直接构建 chunks + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + chunks = [] + for global_idx in range(batch_start, batch_end): + item = res[global_idx] + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + all_batch_chunks.append(chunks) + + # 并发提交 embedding + ES 写入,max_workers 控制模型 API 并发压力 + batch_errors: dict[int, Exception] = {} + + def _embed_and_store(batch_idx: int, batch_chunks: list[DocumentChunk]): + try: + 
vector_service.add_chunks(batch_chunks) + except Exception as exc: + logger.warning(f"[ParseDoc] batch {batch_idx} failed, retrying: {exc}") + try: + vector_service.add_chunks(batch_chunks) + except Exception as retry_exc: + logger.error(f"[ParseDoc] batch {batch_idx} retry failed: {retry_exc}", exc_info=True) + batch_errors[batch_idx] = retry_exc + + with ThreadPoolExecutor(max_workers=EMBEDDING_MAX_WORKERS) as executor: + futures = { + executor.submit(_embed_and_store, i, batch_chunks): i + for i, batch_chunks in enumerate(all_batch_chunks) + } + for future in futures: + future.result() + + # 如果有 batch 失败,汇总抛出 + if batch_errors: + failed_detail = "; ".join( + f"batch {i}: {type(err).__name__}: {err}" + for i, err in sorted(batch_errors.items()) + ) + raise RuntimeError(f"Embedding failed for {len(batch_errors)}/{total_batches} batch(es). {failed_detail}") + + # 所有 batch 完成后一次性更新进度 + db_document.progress = 0.8 + 0.2 # 直接到 1.0 前的状态 + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} All {total_batches} batches embedded (workers={EMBEDDING_MAX_WORKERS}).") + db_document.progress_msg = _progress_msg() db_document.process_duration = time.time() - start_time db_document.run = 0 db.commit() db.refresh(db_document) # Vectorization and data entry completed - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.") db_document.chunk_num = total_chunks db_document.progress = 1.0 db_document.process_duration = time.time() - start_time - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" - db_document.progress_msg = progress_msg + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).") + db_document.progress_msg = _progress_msg() db_document.run = 0 db.commit() - # using graphrag + # GraphRAG: 异步派发到独立队列,不阻塞文档解析流程 if db_knowledge.parser_config and 
db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False): - graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) - with_resolution = graphrag_conf.get("resolution", False) - with_community = graphrag_conf.get("community", False) - - def callback(*args, msg=None, **kwargs): - nonlocal progress_msg - message = msg or (args[0] if args else "No message") - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n" - - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to run graphrag.\n" - start_time = time.time() - db_document.progress_msg = progress_msg + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG enabled, dispatching async task.") + db_document.progress_msg = _progress_msg() db.commit() - db.refresh(db_document) - - task = { - "id": str(db_document.id), - "workspace_id": str(db_knowledge.workspace_id), - "kb_id": str(db_knowledge.id), - "parser_config": db_knowledge.parser_config, - } - - # init_graphrag - vts, _ = embedding_model.encode(["ok"]) - vector_size = len(vts[0]) - init_graphrag(task, vector_size) - - async def _run( - row: dict, - document_ids: list[str], - language: str, - parser_config: dict, - vector_service, - chat_model, - embedding_model, - callback, - with_resolution: bool = True, - with_community: bool = True - ) -> dict: - await trio.sleep(5) # Delay for 10 seconds - nonlocal progress_msg # Declare the use of an external progress_msg variable - result = await run_graphrag_for_kb( - row=row, - document_ids=document_ids, - language=language, - parser_config=parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n" - return result - - def sync_task(): - trio.run( - lambda: _run( - row=task, - 
document_ids=[str(db_document.id)], - language="Chinese", - parser_config=db_knowledge.parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) - ) - - try: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(sync_task) - future.result() # Blocks until the task completes - except Exception as e: - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n" - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)" - db_document.progress_msg = progress_msg - db.commit() - db.refresh(db_document) + build_graphrag_for_document.delay(str(document_id), str(db_knowledge.id)) result = f"parse document '{db_document.file_name}' processed successfully." + logger.info(f"[ParseDoc] document={document_id} file='{db_document.file_name}' done in {db_document.process_duration:.1f}s, chunks={total_chunks}") return result - except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" - db_document.run = 0 - db.commit() - result = f"parse document '{db_document.file_name}' failed." 
- return result - finally: - db.close() + except Exception as e: + logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True) + if db_document is not None: + try: + db.rollback() + db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n" + db_document.run = 0 + db.commit() + except Exception: + logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True) + # db_document 可能处于 detached/expired 状态,用之前缓存的值或 document_id 兜底 + file_name = getattr(db_document, 'file_name', None) if db_document else None + return f"parse document '{file_name or document_id}' failed." @celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb") @@ -410,51 +442,44 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): """ build knowledge graph """ - # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) import importlib import trio importlib.reload(trio) - db = next(get_db()) # Manually call the generator - db_documents = None - db_knowledge = None - try: - db_documents = db.query(Document).filter(Document.kb_id == kb_id).all() - db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() - # 1. Prepare to configure chat_mdl、embedding_model、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - # 2. 
get all document_ids from knowledge base - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - total, items = vector_service.search_by_segment(document_id=None, query=None, pagesize=9999, page=1, asc=True) - document_ids = [str(item.id) for item in db_documents] + with get_db_context() as db: + try: + if not isinstance(kb_id, uuid.UUID): + kb_id = uuid.UUID(str(kb_id)) + + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + if db_knowledge is None: + logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found") + return f"build knowledge graph failed: knowledge not found" + + if not (db_knowledge.parser_config and + db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False)): + return f"build knowledge graph '{db_knowledge.name}' skipped: graphrag not enabled" + + db_documents = db.query(Document).filter(Document.kb_id == kb_id).all() + document_ids = [str(doc.id) for doc in db_documents] + + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base, + ) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2. 
using graphrag - if db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False): graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) - def callback(*args, msg=None, **kwargs): - message = msg or (args[0] if args else "No message") - print(f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n") - - start_time = time.time() task = { "id": str(db_knowledge.id), "workspace_id": str(db_knowledge.workspace_id), @@ -467,14 +492,18 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): vector_size = len(vts[0]) init_graphrag(task, vector_size) - async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service, - chat_model, embedding_model, callback, with_resolution: bool = True, - with_community: bool = True, ) -> dict: - result = await run_graphrag_for_kb( - row=row, + def callback(*args, msg=None, **kwargs): + message = msg or (args[0] if args else "No message") + logger.info(f"[GraphRAG-KB] kb={kb_id} msg: {message}") + + start_time = time.time() + + async def _run() -> dict: + return await run_graphrag_for_kb( + row=task, document_ids=document_ids, - language=language, - parser_config=parser_config, + language=DEFAULT_PARSE_LANGUAGE, + parser_config=db_knowledge.parser_config, vector_service=vector_service, chat_model=chat_model, embedding_model=embedding_model, @@ -482,46 +511,97 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with_resolution=with_resolution, with_community=with_community, ) - print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n") - return result - def sync_task(): - trio.run( - lambda: _run( - row=task, - document_ids=document_ids, - language="Chinese", - parser_config=db_knowledge.parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - 
callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) + result = trio.run(_run) + duration = time.time() - start_time + logger.info(f"[GraphRAG-KB] kb={kb_id} done in {duration:.1f}s, result: {result}") + + return f"build knowledge graph '{db_knowledge.name}' processed successfully." + except Exception as e: + logger.error(f"[GraphRAG-KB] kb={kb_id} failed: {e}", exc_info=True) + return f"build knowledge graph failed: {e}" + + +@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_document") +def build_graphrag_for_document(document_id: str, knowledge_id: str): + """ + 为单个文档构建 GraphRAG,由 parse_document 异步派发。 + """ + import importlib + + import trio + importlib.reload(trio) + + with get_db_context() as db: + try: + db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(knowledge_id)).first() + if db_document is None or db_knowledge is None: + logger.error(f"[GraphRAG] document={document_id} or knowledge={knowledge_id} not found") + return f"build_graphrag_for_document failed: record not found" + + graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) + with_resolution = graphrag_conf.get("resolution", False) + with_community = graphrag_conf.get("community", False) + + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base, + ) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + task = { + "id": document_id, + "workspace_id": str(db_knowledge.workspace_id), + "kb_id": str(db_knowledge.id), + "parser_config": db_knowledge.parser_config, + } + + # init_graphrag + vts, _ = 
embedding_model.encode(["ok"]) + vector_size = len(vts[0]) + init_graphrag(task, vector_size) + + def callback(*args, msg=None, **kwargs): + message = msg or (args[0] if args else "No message") + logger.info(f"[GraphRAG] doc={document_id} msg: {message}") + + start_time = time.time() + + async def _run() -> dict: + await trio.sleep(5) + return await run_graphrag_for_kb( + row=task, + document_ids=[document_id], + language=DEFAULT_PARSE_LANGUAGE, + parser_config=db_knowledge.parser_config, + vector_service=vector_service, + chat_model=chat_model, + embedding_model=embedding_model, + callback=callback, + with_resolution=with_resolution, + with_community=with_community, ) - try: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(sync_task) - future.result() # Blocks until the task completes - except Exception as e: - print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n") - finally: - if db: - db.close() - print(f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)") + result = trio.run(_run) + duration = time.time() - start_time + logger.info(f"[GraphRAG] doc={document_id} done in {duration:.1f}s") - result = f"build knowledge graph '{db_knowledge.name}' processed successfully." - return result - except Exception as e: - if 'db_knowledge' in locals(): - print(f"Failed to build knowledge grap:{str(e)}\n") - result = f"build knowledge grap '{db_knowledge.name}' failed." - return result - finally: - if db: - db.close() + # 更新文档进度信息 + db_document.progress_msg = (db_document.progress_msg or "") + \ + f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({duration:.1f}s)\n" + db.commit() + + return f"build_graphrag_for_document '{document_id}' processed successfully." 
+ except Exception as e: + logger.error(f"[GraphRAG] doc={document_id} failed: {e}", exc_info=True) + return f"build_graphrag_for_document '{document_id}' failed: {e}" @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") @@ -529,10 +609,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): """ sync knowledge document and Document parsing, vectorization, and storage """ - db = next(get_db()) # Manually call the generator - db_knowledge = None - try: + with get_db_context() as db: + try: + if not isinstance(kb_id, uuid.UUID): + kb_id = uuid.UUID(str(kb_id)) + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + if db_knowledge is None: + logger.error(f"[SyncKB] knowledge={kb_id} not found") + return f"sync knowledge failed: knowledge not found" + # 1. get vector_service vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) @@ -667,7 +753,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during crawl: {e}") + logger.error(f"[SyncKB] Error during crawl: {e}", exc_info=True) case "Third-party": # Integration of knowledge bases from three parties yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") @@ -685,13 +771,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): # Get all files from all repos async def async_get_files(api_client: YuqueAPIClient): async with api_client as client: - print("\n=== Fetching repositories ===") repos = await client.get_user_repos() - print(f"Found {len(repos)} repositories:") all_files = [] for repo in repos: - # Get documents from repository - print(f"\n=== Fetching documents from '{repo.name}' ===") docs = await client.get_repo_docs(repo.id) all_files.extend(docs) return all_files @@ -837,7 +919,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during fetch feishu: {e}") + 
logger.error(f"[SyncKB] Error during fetch yuque: {e}", exc_info=True) if feishu_app_id: # Feishu Knowledge Base feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") @@ -999,19 +1081,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during fetch feishu: {e}") + logger.error(f"[SyncKB] Error during fetch feishu: {e}", exc_info=True) case _: # General - print(f"General: No synchronization needed\n") + logger.info(f"[SyncKB] kb={kb_id} type={db_knowledge.type}: no synchronization needed") result = f"sync knowledge '{db_knowledge.name}' processed successfully." return result - except Exception as e: - if 'db_knowledge' in locals(): - print(f"Failed to sync knowledge:{str(e)}\n") - result = f"sync knowledge '{db_knowledge.name}' failed." - return result - finally: - db.close() + except Exception as e: + logger.error(f"[SyncKB] kb={kb_id} failed: {e}", exc_info=True) + kb_name = db_knowledge.name if db_knowledge else kb_id + return f"sync knowledge '{kb_name}' failed: {e}" @celery_app.task(name="app.core.memory.agent.read_message", bind=True) @@ -1100,7 +1179,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s } -@celery_app.task(name="app.core.memory.agent.write_message", bind=True) +@celery_app.task(name="app.core.memory.agent.write_message", bind=True, acks_late=False) def write_message_task( self, end_user_id: str, @@ -1176,6 +1255,7 @@ def write_message_task( redis_client = get_sync_redis_client() lock = None + loop = None if redis_client is not None: lock = RedisFairLock( key=f"memory_write:{end_user_id}", @@ -1196,6 +1276,7 @@ def write_message_task( } try: + task_start_time = int(time.time()) loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1205,7 +1286,7 @@ def write_message_task( f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") 
try: - _r = get_sync_redis_client() + _r = redis_client if _r is not None: from datetime import timezone as _tz _now_utc = datetime.now(_tz.utc).isoformat() @@ -1219,6 +1300,7 @@ def write_message_task( return { "status": "SUCCESS", "result": result, + "start_at": task_start_time, "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, @@ -1252,7 +1334,8 @@ def write_message_task( logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") # Gracefully shutdown the event loop to prevent # 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ - _shutdown_loop_gracefully(loop) + if loop: + _shutdown_loop_gracefully(loop) # unused task @@ -1320,7 +1403,7 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: from app.models.app_model import App from app.models.end_user_model import EndUser from app.repositories.memory_increment_repository import write_memory_increment - from app.services.memory_storage_service import search_all + from app.services.memory_storage_service import search_all_batch with get_db_context() as db: try: @@ -1354,27 +1437,15 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: EndUser.workspace_id == workspace_id ).distinct().all() - # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 - total_num = 0 - end_user_details = [] + # 3. 
批量查询所有宿主的记忆总量 + end_user_id_list = [str(eid) for (eid,) in end_users] + batch_result = await search_all_batch(end_user_id_list) - for (end_user_id,) in end_users: - try: - # 调用 search_all 接口查询该宿主的总量 - result = await search_all(str(end_user_id)) - user_total = result.get("total", 0) - total_num += user_total - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": user_total - }) - except Exception as e: - # 记录单个用户查询失败,但继续处理其他用户 - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": 0, - "error": str(e) - }) + total_num = sum(batch_result.values()) + end_user_details = [ + {"end_user_id": uid, "total": batch_result.get(uid, 0)} + for uid in end_user_id_list + ] # 4. 写入数据库 memory_increment = write_memory_increment( @@ -1437,7 +1508,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: from app.models.end_user_model import EndUser from app.models.workspace_model import Workspace from app.repositories.memory_increment_repository import write_memory_increment - from app.services.memory_storage_service import search_all + from app.services.memory_storage_service import search_all_batch with get_db_context() as db: try: @@ -1495,28 +1566,15 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: EndUser.workspace_id == workspace_id ).distinct().all() - # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 - total_num = 0 - end_user_details = [] + # 3. 
批量查询所有宿主的记忆总量 + end_user_id_list = [str(eid) for (eid,) in end_users] + batch_result = await search_all_batch(end_user_id_list) - for (end_user_id,) in end_users: - try: - # 调用 search_all 接口查询该宿主的总量 - result = await search_all(str(end_user_id)) - user_total = result.get("total", 0) - total_num += user_total - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": user_total - }) - except Exception as e: - # 记录单个用户查询失败,但继续处理其他用户 - logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": 0, - "error": str(e) - }) + total_num = sum(batch_result.values()) + end_user_details = [ + {"end_user_id": uid, "total": batch_result.get(uid, 0)} + for uid in end_user_id_list + ] # 4. 写入数据库 memory_increment = write_memory_increment( @@ -1531,6 +1589,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "status": "SUCCESS", "total_num": total_num, "end_user_count": len(end_users), + "end_user_details": end_user_details, "memory_increment_id": str(memory_increment.id), "created_at": memory_increment.created_at.isoformat(), }) @@ -2623,35 +2682,34 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ service = MemoryAgentService() - with get_db_context() as db: - for end_user_id in end_user_ids: - # 存在性检查:缓存有数据则跳过 - cached = await InterestMemoryCache.get_interest_distribution( + for end_user_id in end_user_ids: + # 存在性检查:缓存有数据则跳过 + cached = await InterestMemoryCache.get_interest_distribution( + end_user_id=end_user_id, + language=language, + ) + if cached is not None: + skipped += 1 + continue + + logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成") + try: + result = await service.get_interest_distribution_by_user( end_user_id=end_user_id, + limit=5, language=language, ) - if cached is not None: - skipped += 1 - continue - - logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成") - try: - result = await service.get_interest_distribution_by_user( - end_user_id=end_user_id, 
- limit=5, - language=language, - ) - await InterestMemoryCache.set_interest_distribution( - end_user_id=end_user_id, - language=language, - data=result, - expire=INTEREST_CACHE_EXPIRE, - ) - initialized += 1 - logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功") - except Exception as e: - failed += 1 - logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}") + await InterestMemoryCache.set_interest_distribution( + end_user_id=end_user_id, + language=language, + data=result, + expire=INTEREST_CACHE_EXPIRE, + ) + initialized += 1 + logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功") + except Exception as e: + failed += 1 + logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}") logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}") return { @@ -2935,4 +2993,270 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace } +# ─── User Metadata Extraction Task ─────────────────────────────────────────── + + +def _update_timestamps(existing: dict, new: dict, updated_at: dict, now: str, prefix: str = "") -> None: + """对比新旧元数据,更新变更字段的 _updated_at 时间戳。""" + for key, new_val in new.items(): + if key == "_updated_at": + continue + path = f"{prefix}.{key}" if prefix else key + old_val = existing.get(key) + + if isinstance(new_val, dict) and isinstance(old_val, dict): + _update_timestamps(old_val, new_val, updated_at, now, prefix=path) + elif old_val != new_val: + updated_at[path] = now + +@celery_app.task( + bind=True, + name='app.tasks.extract_user_metadata', + ignore_result=False, + max_retries=0, + acks_late=True, + time_limit=300, + soft_time_limit=240, +) +def extract_user_metadata_task( + self, + end_user_id: str, + statements: List[str], + config_id: Optional[str] = None, + language: str = "zh", +) -> Dict[str, Any]: + """异步提取用户元数据并写入数据库。 + + 在去重消歧完成后由编排器触发,使用独立 LLM 调用提取元数据。 + LLM 配置优先使用 config_id 对应的应用配置,失败时回退到工作空间默认配置。 + + Args: + end_user_id: 终端用户 ID + statements: 用户相关的 statement 文本列表 + config_id: 应用配置 ID(可选) + language: 语言类型 ("zh" 中文, "en" 英文) + + 
Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + logger.info( + f"[CELERY METADATA] Starting metadata extraction - end_user_id={end_user_id}, " + f"statements_count={len(statements)}, config_id={config_id}, language={language}" + ) + + async def _run() -> Dict[str, Any]: + from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import MetadataExtractor + from app.repositories.end_user_info_repository import EndUserInfoRepository + from app.repositories.end_user_repository import EndUserRepository + from app.services.memory_config_service import MemoryConfigService + + # 1. 获取 LLM 配置(应用配置 → 工作空间配置兜底)并创建 LLM client + with get_db_context() as db: + end_user_uuid = uuid.UUID(end_user_id) + + # 获取 workspace_id from end_user + end_user = EndUserRepository(db).get_by_id(end_user_uuid) + if not end_user: + return {"status": "FAILURE", "error": f"End user not found: {end_user_id}"} + + workspace_id = end_user.workspace_id + + config_service = MemoryConfigService(db) + memory_config = config_service.get_config_with_fallback( + memory_config_id=uuid.UUID(config_id) if config_id else None, + workspace_id=workspace_id, + ) + if not memory_config: + return {"status": "FAILURE", "error": "No LLM config available (app + workspace fallback failed)"} + + # 2. 
创建 LLM client + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + factory = MemoryClientFactory(db) + if not memory_config.llm_id: + return {"status": "FAILURE", "error": "Memory config has no LLM model configured"} + llm_client = factory.get_llm_client(memory_config.llm_id) + + # 2.5 读取已有元数据和别名,传给 extractor 作为上下文 + existing_metadata = None + existing_aliases = None + try: + info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + if info: + if info.meta_data: + existing_metadata = info.meta_data + existing_aliases = info.aliases if info.aliases else [] + logger.info(f"[CELERY METADATA] 已读取已有元数据和别名(aliases={existing_aliases})") + except Exception as e: + logger.warning(f"[CELERY METADATA] 读取已有数据失败(继续无上下文提取): {e}") + + # 3. 提取元数据和别名(传入已有数据作为上下文) + extractor = MetadataExtractor(llm_client=llm_client, language=language) + extract_result = await extractor.extract_metadata( + statements, + existing_metadata=existing_metadata, + existing_aliases=existing_aliases, + ) + + if not extract_result: + logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}") + return {"status": "SUCCESS", "result": "no_metadata_extracted"} + + user_metadata, aliases_to_add, aliases_to_remove = extract_result + logger.info(f"[CELERY METADATA] LLM 别名新增: {aliases_to_add}, 移除: {aliases_to_remove}") + + # 4. 
清洗元数据、覆盖写入元数据和别名 + def clean_metadata(raw: dict) -> dict: + """递归移除空字符串、空列表、空字典。""" + result = {} + for k, v in raw.items(): + if v == "" or v == []: + continue + if isinstance(v, dict): + cleaned = clean_metadata(v) + if cleaned: + result[k] = cleaned + else: + result[k] = v + return result + + raw_dict = user_metadata.model_dump(exclude_none=True) if user_metadata else {} + logger.info(f"[CELERY METADATA] LLM 输出完整元数据: {json.dumps(raw_dict, ensure_ascii=False)}") + + cleaned = clean_metadata(raw_dict) if raw_dict else {} + logger.info(f"[CELERY METADATA] 清洗后元数据: {json.dumps(cleaned, ensure_ascii=False)}") + + from datetime import datetime as dt, timezone as tz + now = dt.now(tz.utc).isoformat() + + # 过滤别名中的占位名称,执行增量增删 + _PLACEHOLDER_NAMES = {"用户", "我", "user", "i"} + + def _filter_aliases(aliases_list): + seen = set() + result = [] + for a in aliases_list: + a_stripped = a.strip() + if a_stripped and a_stripped.lower() not in _PLACEHOLDER_NAMES and a_stripped.lower() not in seen: + result.append(a_stripped) + seen.add(a_stripped.lower()) + return result + + filtered_add = _filter_aliases(aliases_to_add) + filtered_remove = _filter_aliases(aliases_to_remove) + remove_lower = {a.lower() for a in filtered_remove} + + with get_db_context() as db: + end_user_uuid = uuid.UUID(end_user_id) + info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + end_user = EndUserRepository(db).get_by_id(end_user_uuid) + + if info: + # 元数据覆盖写入 + if cleaned: + existing_meta = info.meta_data if info.meta_data else {} + updated_at = dict(existing_meta.get("_updated_at", {})) + _update_timestamps(existing_meta, cleaned, updated_at, now) + final = dict(cleaned) + final["_updated_at"] = updated_at + info.meta_data = final + logger.info("[CELERY METADATA] 覆盖写入元数据") + + # 别名增量增删:(已有 - remove) + add + old_aliases = info.aliases if info.aliases else [] + # 先移除 + merged = [a for a in old_aliases if a.strip().lower() not in remove_lower] + # 再追加(去重) + existing_lower = 
{a.strip().lower() for a in merged} + for a in filtered_add: + if a.lower() not in existing_lower: + merged.append(a) + existing_lower.add(a.lower()) + + if merged != old_aliases: + info.aliases = merged + # other_name 更新逻辑 + if merged and ( + not info.other_name + or info.other_name.strip().lower() in _PLACEHOLDER_NAMES + or info.other_name.strip().lower() in remove_lower + ): + info.other_name = merged[0] + if end_user and merged and ( + not end_user.other_name + or end_user.other_name.strip().lower() in _PLACEHOLDER_NAMES + or end_user.other_name.strip().lower() in remove_lower + ): + end_user.other_name = merged[0] + logger.info( + f"[CELERY METADATA] 别名增量更新: {old_aliases} - {filtered_remove} + {filtered_add} → {merged}" + ) + else: + # 没有 end_user_info 记录,创建一条 + from app.models.end_user_info_model import EndUserInfo + initial_aliases = filtered_add # 新记录只有 add,没有 remove + first_alias = initial_aliases[0] if initial_aliases else "" + if first_alias or cleaned: + new_info = EndUserInfo( + end_user_id=end_user_uuid, + other_name=first_alias or "", + aliases=initial_aliases, + meta_data=cleaned if cleaned else None, + ) + db.add(new_info) + if end_user and first_alias and ( + not end_user.other_name or end_user.other_name.strip().lower() in _PLACEHOLDER_NAMES + ): + end_user.other_name = first_alias + logger.info(f"[CELERY METADATA] 创建 end_user_info: other_name={first_alias}, aliases={initial_aliases}") + else: + return {"status": "SUCCESS", "result": "no_data_to_write"} + + db.commit() + + # 同步 PgSQL aliases 到 Neo4j 用户实体(PgSQL 为权威源) + final_aliases = info.aliases if info else initial_aliases + if final_aliases: + try: + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + neo4j_connector = Neo4jConnector() + cypher = """ + MATCH (e:ExtractedEntity) + WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I'] + SET e.aliases = $aliases + """ + await neo4j_connector.execute_query( + cypher, end_user_id=end_user_id, 
aliases=final_aliases + ) + await neo4j_connector.close() + logger.info(f"[CELERY METADATA] Neo4j 用户实体 aliases 已同步: {final_aliases}") + except Exception as neo4j_err: + logger.warning(f"[CELERY METADATA] Neo4j aliases 同步失败(不影响主流程): {neo4j_err}") + + return {"status": "SUCCESS", "result": "metadata_and_aliases_written"} + + loop = None + try: + loop = set_asyncio_event_loop() + result = loop.run_until_complete(_run()) + elapsed = time.time() - start_time + result["elapsed_time"] = elapsed + result["task_id"] = self.request.id + logger.info(f"[CELERY METADATA] Task completed - elapsed={elapsed:.2f}s, result={result.get('result')}") + return result + + except Exception as e: + elapsed = time.time() - start_time + logger.error(f"[CELERY METADATA] Task failed - elapsed={elapsed:.2f}s, error={e}", exc_info=True) + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed, + "task_id": self.request.id, + } + finally: + if loop: + _shutdown_loop_gracefully(loop) + + # unused task \ No newline at end of file diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index bc03bb28..4f88fb4d 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -153,7 +153,8 @@ def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig: edges=config_dict.get("edges", []), variables=config_dict.get("variables", []), execution_config=config_dict.get("execution_config", {}), - triggers=config_dict.get("triggers", []) + triggers=config_dict.get("triggers", []), + features=config_dict.get("features", {}) ) return config diff --git a/api/app/utils/performance_timer.py b/api/app/utils/performance_timer.py index 6b0ec5d6..04e52fb1 100644 --- a/api/app/utils/performance_timer.py +++ b/api/app/utils/performance_timer.py @@ -6,13 +6,13 @@ """ import time -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from app.core.logging_config import get_api_logger # 
获取API专用日志器 api_logger = get_api_logger() - +# 同步的上下文管理器,使用@contextmanager修饰 @contextmanager def timer(label: str, user_count: int = 0): """上下文管理器:用于测量代码块执行时间 @@ -35,3 +35,27 @@ def timer(label: str, user_count: int = 0): elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒 extra_info = f", 用户数: {user_count}" if user_count > 0 else "" api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}") + +# 异步的上下文管理器,使用@asynccontextmanager装饰 +@asynccontextmanager +async def async_timer(label: str, user_count: int = 0): + """异步上下文管理器:用于测量包含 await 的异步代码块执行时间 + + Args: + label: 统计标签,用于标识被测量的代码块 + user_count: 用户数,可选参数,用于记录处理的用户数量 + + Usage: + async with async_timer("获取用户列表"): + users = await get_users() + + async with async_timer("批量处理", user_count=len(user_ids)): + await process_users(user_ids) + """ + start = time.perf_counter() + try: + yield + finally: + elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒 + extra_info = f", 用户数: {user_count}" if user_count > 0 else "" + api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}") diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index a86ba46e..b192c129 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,14 +1,21 @@ -import redis -import uuid -import time +import logging import threading +import time +import uuid + +import redis +from redis.exceptions import ( + ConnectionError, + TimeoutError, + RedisError, +) UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) -else - return 0 end + +return 0 """ RENEW_SCRIPT = """ @@ -19,38 +26,44 @@ else end """ -CLEANUP_DEAD_HEAD_SCRIPT = """ +ACQUIRE_SCRIPT = """ local queue_key = KEYS[1] local lock_key = KEYS[2] -local first = redis.call("lindex", queue_key, 0) -if not first then - return 0 +local client_id = ARGV[1] +local expire = tonumber(ARGV[2]) +local time_out = tonumber(ARGV[3]) + +local now = tonumber(redis.call("time")[1]) + +if redis.call("zscore", queue_key, 
client_id) == false then + redis.call("zadd", queue_key, now, client_id) end -if redis.call("exists", lock_key) == 1 then - return 0 +local expired = redis.call("zrangebyscore", queue_key, 0, now - time_out) + +for _, v in ipairs(expired) do + redis.call("zrem", queue_key, v) end -redis.call("lpop", queue_key) -return 1 -""" +local first = redis.call("zrange", queue_key, 0, 0)[1] +if first == client_id then -SAFE_RELEASE_QUEUE_SCRIPT = """ -local queue_key = KEYS[1] -local value = ARGV[1] + if redis.call("set", lock_key, client_id, "NX", "EX", expire) then + redis.call("zrem", queue_key, client_id) + return 1 + end -local first = redis.call("lindex", queue_key, 0) -if first == value then - redis.call("lpop", queue_key) - return 1 + if redis.call("get", lock_key) == client_id then + redis.call("expire", lock_key, expire) + return 1 + end end return 0 """ def _ensure_str(val): - """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" if val is None: return None if isinstance(val, bytes): @@ -59,18 +72,27 @@ def _ensure_str(val): class RedisFairLock: + # ZOMBIE CLEAN BUFFER + CLEANUP_BUFFER = 30 + # Redis 操作失败时的最大重试次数 + MAX_RETRIES = 3 + # 重试间隔基数(秒),实际间隔 = base * 2^attempt(指数退避) + RETRY_BACKOFF_BASE = 0.1 + + _logger = logging.getLogger(__name__) + def __init__( self, key: str, redis_client: redis.StrictRedis, expire: int = 30, - retry_interval: float = 0.05, + retry_interval: float = 1, timeout: float = 600, auto_renewal: bool = True ): self.key = key - self.queue_key = f"{key}:queue" - self.value = str(uuid.uuid4()) + self.queue_key = f"{key}:zset" + self.value = f"{uuid.uuid4().hex}:{int(time.time())}" self.expire = expire self.retry_interval = retry_interval self.timeout = timeout @@ -80,28 +102,56 @@ class RedisFairLock: self._renew_thread = None self._stop_renew = threading.Event() + def _exec_with_retry(self, func, *args, raise_on_fail=True, **kwargs): + """ + 带指数退避重试的 Redis 操作执行器。 + + 对 ConnectionError / TimeoutError 自动重试,其他异常直接抛出。 + """ + last_err = None + 
for attempt in range(self.MAX_RETRIES): + try: + return func(*args, **kwargs) + except (ConnectionError, TimeoutError) as e: + last_err = e + wait = self.RETRY_BACKOFF_BASE * (2 ** attempt) + self._logger.warning( + f"[RedisFairLock] Redis error on attempt {attempt + 1}/{self.MAX_RETRIES} " + f"for key={self.key}: {e}, retrying in {wait:.2f}s" + ) + time.sleep(wait) + except RedisError: + raise + if raise_on_fail: + raise last_err + return None + def acquire(self): start = time.time() - self.redis.rpush(self.queue_key, self.value) - while True: - first = _ensure_str(self.redis.lindex(self.queue_key, 0)) + ok = self._exec_with_retry( + self.redis.eval, + ACQUIRE_SCRIPT, + 2, + self.queue_key, + self.key, + self.value, + str(self.expire), + str(self.timeout + self.CLEANUP_BUFFER), + ) - if first == self.value: - ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) - if ok: - self._locked = True - - if self.auto_renewal: - self._start_renewal() - return True - - if first: - self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) + if ok == 1: + self._locked = True + if self.auto_renewal: + self._start_renewal() + return True if time.time() - start > self.timeout: - self.redis.lrem(self.queue_key, 0, self.value) + self._exec_with_retry( + self.redis.zrem, self.queue_key, self.value, + raise_on_fail=False, + ) return False time.sleep(self.retry_interval) @@ -112,13 +162,17 @@ class RedisFairLock: if self._stop_renew.is_set(): break - self.redis.eval( + success = self._exec_with_retry( + self.redis.eval, RENEW_SCRIPT, 1, self.key, self.value, - str(self.expire) + str(self.expire), + raise_on_fail=False, ) + if not success: + break def _start_renewal(self): self._stop_renew = threading.Event() @@ -137,9 +191,10 @@ class RedisFairLock: if self.auto_renewal: self._stop_renewal() - self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) - - self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) + self._exec_with_retry( + 
self.redis.eval, UNLOCK_SCRIPT, 1, self.key, self.value, + raise_on_fail=False, + ) self._locked = False @@ -151,4 +206,3 @@ class RedisFairLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() - diff --git a/api/app/version_info.json b/api/app/version_info.json index b4f6976f..a094b64c 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,72 @@ { + "v0.3.0": { + "introduction": { + "codeName": "破晓", + "releaseDate": "2026-4-15", + "upgradePosition": "🐻 全面升级应用工作流、记忆智能与系统稳健性,引入版本化API、多模态记忆感知及大量工作流增强,打造更可靠、精准的 MemoryBear", + "coreUpgrades": [ + "1. 应用与API增强
* 版本化API调用支持:对外服务API支持指定版本调用
* 工作流检查清单:新增结构化验证步骤
* 深度思考参数精准控制:仅向支持深度推理的模型发送思考参数
* 提示器模型返回优化:优化提示器模型响应处理", + "2. 记忆智能 🧠
* 多模态记忆感知Agent:支持多模态记忆读取与写入
* OpenClaw内置工具:新增内置工具扩展Agent工具集", + "3. 用户体验 🎨
* 流式渲染稳定性优化:解决LLM流式输出页面抖动问题
* 记忆中枢更名:「记忆相关」更名为「记忆中枢」", + "4. 工作流改进 ⚙️
* 三级变量模板转换:支持三级变量解析
* VL模型Token统计:修复模型组合中VL模型Token未统计问题
* 导入工作流功能特性同步:正确同步开场白、引用等属性
* 会话变量名称唯一性校验:防止变量名冲突
* 文件类型提取修复:正确提取file.type信息
* 条件分支显示修复:值为0或会话变量时正确渲染
* Object/Array校验规则:防止JSON序列化错误
* HTTP请求Body字段修正:body字段从name改为key", + "5. 知识库 📚
* Embedding Token截断安全边界:统一添加8000 token截断,优化Excel独立chunk处理", + "6. 稳健性与缺陷修复 🔧
* 原子性更新与批量访问失败修复
* 对话别名提取错误修复
* 工作流别名提取修正(区分用户和AI回复)
* RAG记忆分页数据修复
* 隐式记忆详情显示修复
* 向量查询驱动关闭异常修复
* 用户管理启停异常修复
* 模型列表筛选不一致修复", + "
", + "v0.3.0 标志着 MemoryBear 向生产成熟度迈出坚实一步。后续版本将持续深化工作流表达力、记忆检索精度和跨模态理解能力,强化复杂Agent编排支持,稳固大规模生产部署基础。", + "
", + "MemoryBear — 破晓 🐻✨" + ] + }, + "introduction_en": { + "codeName": "PoXiao", + "releaseDate": "2026-4-15", + "upgradePosition": "🐻 Comprehensive upgrades across application workflows, memory intelligence, and system robustness — introducing versioned APIs, multimodal memory perception, and extensive workflow enhancements for a more reliable MemoryBear", + "coreUpgrades": [ + "1. Application & API Enhancements
* Versioned API Support: External APIs now support version-specific calls
* Workflow Checklist: Structured validation steps before deployment
* Deep Thinking Parameter Control: Only send thinking params to supported models
* Prompt Optimizer Return Optimization: Improved prompt optimizer response handling", + "2. Memory Intelligence 🧠
* Multimodal Memory Perception Agent: Read/write multimodal memory
* OpenClaw Built-in Tool: New built-in tool for agent operations", + "3. User Experience 🎨
* Streaming Render Stabilization: Eliminated page jitter during LLM output
* Memory Hub Renaming: Renamed to better reflect central memory role", + "4. Workflow Improvements ⚙️
* Three-Level Variable Template Conversion: Support for three-level variable resolution
* VL Model Token Tracking: Fixed token tracking for VL models in model groups
* Imported Workflow Feature Sync: Properly sync opening messages, citations, etc.
* Session Variable Name Uniqueness: Prevent variable name conflicts
* File Type Extraction Fix: Correctly extract file.type information
* Condition Branch Display Fix: Correct rendering for value 0 or session variables
* Object/Array Validation Rules: Prevent JSON serialization errors on save
* HTTP Request Body Key Fix: Body field uses key instead of name", + "5. Knowledge Base 📚
* Embedding Token Truncation Safety: Unified 8000-token boundary, optimized Excel chunk processing", + "6. Robustness & Bug Fixes 🔧
* Atomic update & batch access failure fixes
* Conversation alias extraction fix
* Workflow alias extraction correction (user vs AI distinction)
* RAG memory pagination fix
* Implicit memory detail display fix
* Vector query driver-closed exception fix
* User management enable/disable fix
* Model list filter inconsistency fix", + "
", + "v0.3.0 marks a meaningful step toward production maturity for MemoryBear. Upcoming releases will deepen workflow expressiveness, memory retrieval precision, and cross-modal understanding while strengthening complex agent orchestration and large-scale deployment foundations.", + "
", + "MemoryBear — Daybreak 🐻✨" + ] + } + }, + "v0.2.10": { + "introduction": { + "codeName": "炼剑", + "releaseDate": "2026-4-8", + "upgradePosition": "🐻 全面强化工作流引擎、引入 Agent 深度思考模式与多模态记忆读取,百炼成锋,剑指生产就绪", + "coreUpgrades": [ + "1. 工作流引擎增强
* 会话变量文件格式支持:支持文件类型值及本地/远程默认值配置
* 列表操作节点:新增专用列表操作节点
* 模板转换支持 HTML:扩展富内容渲染能力
* 表单返回与提交:工作流返回交互式表单,前端支持提交
* HTTP 节点 XML 响应:拓宽企业级 API 集成兼容性
* 开场白与文件引用:支持配置开场白及附件引用
* 模板转换三级变量:支持深层嵌套变量访问
* 节点连线添加按钮:连线处新增内联添加按钮", + "2. Agent 智能 🧠
* Agent 深度思考模式:支持更充分的推理以产出高质量回答
* 模型深度思考特性开关:模型级特性标识与应用级开关控制", + "3. 记忆系统升级 📚
* 用户记忆库分页:支持大规模记忆集合分页浏览
* RAG 用户记忆数据结构刷新:后端 API 数据结构重新设计
* 多模态记忆读取:支持检索图像、音频等非文本记忆
* 语义剪枝阈值提示文案:显示描述性区间标签", + "4. 前端与体验 🎨
* 技能工具删除状态展示:工具列表显示删除状态标识
* 仪表盘日环比数据:关键指标增加与昨日对比数据", + "5. 稳健性与缺陷修复 🔧
* 参数提取空值处理:优雅处理缺失数据
* Token 消耗展示优化:确保用量报告准确
* 模型参数负值修复:明确参数范围定义
* 应用共享删除同步:正确更新所有共享记录
* 记忆写入任务排序:按时间戳顺序执行
* 多模态模型缺失优雅处理:不再中断感知记忆写入
* 自定义工具 Number 变量传递:解决类型转换问题
* 集群子代理保存后显示:修复未反显问题
* 记忆开启后流式输出修复:解决字符串序列化问题", + "
", + "v0.2.10 标志着平台向生产成熟度迈出的重要一步。深度思考、交互式表单工作流与多模态记忆的结合展现了平台从记忆存储向综合认知基础设施的演进。我们期待 4 月 17 日 v0.3.0 发布会,届时将带来更深层的 Agent 推理能力、多智能体协作功能及记忆智能管线的进一步优化。剑已炼成,只待出鞘。", + "MemoryBear — 百炼成锋 🐻✨" + ] + }, + "introduction_en": { + "codeName": "LianJian", + "releaseDate": "2026-4-8", + "upgradePosition": "🐻 Comprehensive workflow engine enhancements, Agent deep thinking mode, and multimodal memory reading — forging the blade for production readiness", + "coreUpgrades": [ + "1. Workflow Engine Enhancements
* Session Variable File Support: File-type values with local/remote defaults
* List Operation Node: Dedicated node for array manipulation
* Template Conversion HTML Support: Rich-content rendering
* Form Return & Submission: Interactive forms in workflow conversations
* HTTP Node XML Response: Enterprise API integration compatibility
* Opening Remarks & File References: Configurable conversation openers
* Template Conversion Three-Level Variables: Deep nested variable access
* Node Connection Add Button: Inline add button on connections", + "2. Agent Intelligence 🧠
* Agent Deep Thinking Mode: Thorough reasoning for complex queries
* Model Deep Thinking Feature Toggle: Model-level flag with per-app control", + "3. Memory System Upgrades 📚
* User Memory Pagination: Paginated browsing for large collections
* RAG User Memory Data Structure Refresh: Redesigned backend API contracts
* Multimodal Memory Reading: Retrieval of image, audio, and non-text memory
* Semantic Pruning Threshold Hints: Descriptive range labels for configuration", + "4. Frontend & Usability 🎨
* Skill Tool Deletion Status Display: Deletion indicators in tool list
* Dashboard Day-over-Day Comparison: Key metrics with yesterday comparison", + "5. Robustness & Bug Fixes 🔧
* Parameter Extraction Null Handling: Graceful handling of missing data
* Token Consumption Display Optimization: Accurate usage reporting
* Model Parameter Negative Value Fix: Clear parameter range definitions
* App Share Deletion Sync: Correct update of all share records
* Memory Write Task Ordering: Chronological execution per end_user
* Multimodal Model Missing Graceful Handling: No more interrupted writes
* Custom Tool Number Variable Pass-through: Type coercion fix
* Cluster Sub-Agent Display After Save: Fixed saved sub-agents not being shown in the UI
* Memory-Enabled Streaming Output Fix: String serialization resolved", + "
", + "v0.2.10 marks a significant step toward production maturity. The combination of deep thinking, interactive form workflows, and multimodal memory demonstrates the platform's evolution from memory storage to comprehensive cognitive infrastructure. We look forward to the v0.3.0 launch on April 17, bringing deeper agent reasoning, multi-agent collaboration, and further memory intelligence refinements. The blade has been forged — now it's time to wield it.", + "MemoryBear — Forging the Blade 🐻✨" + ] + } + }, "v0.2.8": { "introduction": { "codeName": "景玉", diff --git a/web/package.json b/web/package.json index 0284f397..b41ab9b5 100644 --- a/web/package.json +++ b/web/package.json @@ -16,6 +16,7 @@ "@codemirror/lang-cpp": "^6.0.3", "@codemirror/lang-java": "^6.0.2", "@codemirror/lang-javascript": "^6.2.4", + "@codemirror/lang-json": "^6.0.2", "@codemirror/lang-python": "^6.2.1", "@codemirror/lang-rust": "^6.0.2", "@codemirror/state": "^6.5.4", diff --git a/web/src/App.tsx b/web/src/App.tsx index a10f9409..1af38372 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -16,7 +16,7 @@ import { ConfigProvider, App as AntdApp } from 'antd'; -import { useTranslation } from 'react-i18next'; +import i18n from 'i18next'; import { lightTheme } from './styles/antdThemeConfig.ts' import router from './routes'; @@ -29,11 +29,58 @@ import 'dayjs/plugin/utc' import { cookieUtils } from './utils/request'; import { useUser } from '@/store/user'; +import menuJson from '@/store/menu.json'; + +type MenuEntry = { path: string; i18nKey: string }; + +function flattenMenuEntries(list: any[]): MenuEntry[] { + const result: MenuEntry[] = []; + for (const item of list) { + if (item.path && item.i18nKey && item.type !== 'group') result.push({ path: item.path, i18nKey: item.i18nKey }); + if (item.subs?.length) result.push(...flattenMenuEntries(item.subs)); + } + return result; +} + +const menuEntries: MenuEntry[] = flattenMenuEntries([...menuJson.manage, ...menuJson.space]); + +function 
pathMatches(pattern: string, path: string): boolean { + if (pattern === path) return true; + if (pattern.includes(':')) { + return new RegExp('^' + pattern.replace(/:[\w-]+/g, '[^/]+') + '$').test(path); + } + return false; +} + +function getPageTitle(pathname: string): string { + const appName = i18n.t('memoryBear'); + const entry = menuEntries.find(e => pathMatches(e.path, pathname)); + if (!entry) return appName; + return `${i18n.t(entry.i18nKey)} - ${appName}`; +} + +const SKIP_TITLE_PATTERNS = [ + '/user-memory/detail/:id/:type', + '/forgetting-engine/:id', + '/memory-extraction-engine/:id', + '/emotion-engine/:id', + '/reflection-engine/:id', +]; + + + function App() { - const { t } = useTranslation(); const { locale, language, timeZone } = useI18n() const { checkJump } = useUser(); + useEffect(() => { + const unsubscribe = router.subscribe(({ location }) => { + if (SKIP_TITLE_PATTERNS.some(p => pathMatches(p, location.pathname))) return; + document.title = getPageTitle(location.pathname); + }); + return () => unsubscribe(); + }, []) + useEffect(() => { const authToken = cookieUtils.get('authToken') if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump') && !window.location.hash.includes('#/invite-register')) { @@ -44,7 +91,9 @@ function App() { }, []) useEffect(() => { - document.title = t('memoryBear') + if (!SKIP_TITLE_PATTERNS.some(p => pathMatches(p, router.state.location.pathname))) { + document.title = getPageTitle(router.state.location.pathname) + } dayjs.locale(language) localStorage.setItem('language', language) }, [language]) diff --git a/web/src/api/knowledgeBase.ts b/web/src/api/knowledgeBase.ts index 05200221..52384d06 100644 --- a/web/src/api/knowledgeBase.ts +++ b/web/src/api/knowledgeBase.ts @@ -68,7 +68,7 @@ export const getModelTypeList = async () => { return response as any[]; }; // 获取模型列表 -export const getModelList = async (types: 
string[], pageInfo: PageRequest) => { +export const getModelList = async (pageInfo: PageRequest, types?: string[]) => { const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, type: types?.join(','), is_active: true }); return response as any; }; diff --git a/web/src/api/package.ts b/web/src/api/package.ts new file mode 100644 index 00000000..da52d355 --- /dev/null +++ b/web/src/api/package.ts @@ -0,0 +1,14 @@ +import { request } from '@/utils/request' + +import type { Package } from '@/views/Package/types' + +export const SYS_API_PREFIX = '/sys'; +// 套餐列表 +export const getPackageListUrl = `${SYS_API_PREFIX}/package-plans` +export const getPackageList = (query: { category: Package['category']; status: boolean; }) => { + return request.get(getPackageListUrl, query) +} +// 获取套餐详情 +export const getPackageDetail = (package_plan_id: string) => { + return request.get(`${SYS_API_PREFIX}/package-plans/${package_plan_id}`) +} \ No newline at end of file diff --git a/web/src/assets/images/common/return.svg b/web/src/assets/images/common/return.svg deleted file mode 100644 index cb8166c0..00000000 --- a/web/src/assets/images/common/return.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - 退出 - - - - - - - - - - - - - - \ No newline at end of file diff --git a/web/src/assets/images/conversation/compress.svg b/web/src/assets/images/conversation/compress.svg new file mode 100644 index 00000000..640d80ba --- /dev/null +++ b/web/src/assets/images/conversation/compress.svg @@ -0,0 +1,18 @@ + + + 编组 35 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/conversation/expand.svg b/web/src/assets/images/conversation/expand.svg new file mode 100644 index 00000000..8cc87d99 --- /dev/null +++ b/web/src/assets/images/conversation/expand.svg @@ -0,0 +1,15 @@ + + + 编组 36 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/audio_disabled.svg b/web/src/assets/images/file/audio_disabled.svg new file mode 
100644 index 00000000..93d83a0a --- /dev/null +++ b/web/src/assets/images/file/audio_disabled.svg @@ -0,0 +1,13 @@ + + + 音乐 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/csv_disabled.svg b/web/src/assets/images/file/csv_disabled.svg new file mode 100644 index 00000000..29add1f6 --- /dev/null +++ b/web/src/assets/images/file/csv_disabled.svg @@ -0,0 +1,18 @@ + + + 编组 57 + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/excel_disabled.svg b/web/src/assets/images/file/excel_disabled.svg new file mode 100644 index 00000000..5e2136e9 --- /dev/null +++ b/web/src/assets/images/file/excel_disabled.svg @@ -0,0 +1,17 @@ + + + Excel + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/html_disabled.svg b/web/src/assets/images/file/html_disabled.svg new file mode 100644 index 00000000..fa237301 --- /dev/null +++ b/web/src/assets/images/file/html_disabled.svg @@ -0,0 +1,17 @@ + + + Word + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/json_disabled.svg b/web/src/assets/images/file/json_disabled.svg new file mode 100644 index 00000000..267e2b46 --- /dev/null +++ b/web/src/assets/images/file/json_disabled.svg @@ -0,0 +1,14 @@ + + + JSON + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/md_disabled.svg b/web/src/assets/images/file/md_disabled.svg new file mode 100644 index 00000000..8fe81fe7 --- /dev/null +++ b/web/src/assets/images/file/md_disabled.svg @@ -0,0 +1,19 @@ + + + PDF + + + + + + + + + + MD + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/pause.svg b/web/src/assets/images/file/pause.svg new file mode 100644 index 00000000..0e26ece0 --- /dev/null +++ b/web/src/assets/images/file/pause.svg @@ -0,0 +1,16 @@ + + + 播放 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/pdf_disabled.svg 
b/web/src/assets/images/file/pdf_disabled.svg new file mode 100644 index 00000000..950edcb8 --- /dev/null +++ b/web/src/assets/images/file/pdf_disabled.svg @@ -0,0 +1,20 @@ + + + PDF + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/play.svg b/web/src/assets/images/file/play.svg new file mode 100644 index 00000000..f2ff9cb7 --- /dev/null +++ b/web/src/assets/images/file/play.svg @@ -0,0 +1,28 @@ + + + 播放 + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/ppt_disabled.svg b/web/src/assets/images/file/ppt_disabled.svg new file mode 100644 index 00000000..f3da453e --- /dev/null +++ b/web/src/assets/images/file/ppt_disabled.svg @@ -0,0 +1,14 @@ + + + file-ppt-2-fill + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/txt_disabled.svg b/web/src/assets/images/file/txt_disabled.svg new file mode 100644 index 00000000..100565ce --- /dev/null +++ b/web/src/assets/images/file/txt_disabled.svg @@ -0,0 +1,14 @@ + + + txt + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/video_disabled.svg b/web/src/assets/images/file/video_disabled.svg new file mode 100644 index 00000000..f8f71c2a --- /dev/null +++ b/web/src/assets/images/file/video_disabled.svg @@ -0,0 +1,16 @@ + + + 编组 59 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/file/word_disabled.svg b/web/src/assets/images/file/word_disabled.svg new file mode 100644 index 00000000..d4f9e6ec --- /dev/null +++ b/web/src/assets/images/file/word_disabled.svg @@ -0,0 +1,15 @@ + + + Word + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/checkList.svg b/web/src/assets/images/workflow/checkList.svg new file mode 100644 index 00000000..169743dc --- /dev/null +++ b/web/src/assets/images/workflow/checkList.svg @@ -0,0 +1,16 @@ + + + 参与 + + + + + + + + + + + + + \ 
No newline at end of file diff --git a/web/src/assets/images/workflow/features.svg b/web/src/assets/images/workflow/features.svg index 2ff48584..bd31b107 100644 --- a/web/src/assets/images/workflow/features.svg +++ b/web/src/assets/images/workflow/features.svg @@ -1,12 +1,14 @@ 参与 - - + + - - + + + + diff --git a/web/src/assets/images/workflow/list-operator.svg b/web/src/assets/images/workflow/list-operator.svg new file mode 100644 index 00000000..8091c04b --- /dev/null +++ b/web/src/assets/images/workflow/list-operator.svg @@ -0,0 +1,19 @@ + + + 编组 13 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/BtnTabs/index.tsx b/web/src/components/BtnTabs/index.tsx index 772a4c8d..8a6e670b 100644 --- a/web/src/components/BtnTabs/index.tsx +++ b/web/src/components/BtnTabs/index.tsx @@ -24,10 +24,11 @@ interface BtnTabsProps { onChange: (key: string) => void; /** Optional extra class name for the container */ className?: string; + variant?: 'outline' | 'borderless' } /** Button-style tab switcher — renders tabs as pill-shaped buttons with active highlight */ -const BtnTabs: FC = ({ items, activeKey, onChange, className }) => { +const BtnTabs: FC = ({ items, activeKey, onChange, className, variant = 'borderless' }) => { return ( {items.map((tab) => ( @@ -35,8 +36,9 @@ const BtnTabs: FC = ({ items, activeKey, onChange, className }) => key={tab.key} onClick={() => onChange(tab.key)} className={clsx('rb:px-2 rb:py-1 rb:rounded-[13px] rb:text-[12px] rb:leading-4.5 rb:cursor-pointer', { - 'rb:bg-[#F6F6F6]': activeKey !== tab.key, - 'rb:bg-[#171719] rb:text-white': activeKey === tab.key, + 'rb:bg-[#F6F6F6]': activeKey !== tab.key && variant === 'borderless', + 'rb-border rb:bg-white': activeKey !== tab.key && variant === 'outline', + 'rb:bg-[#171719] rb:text-white rb:border-[#171719]': activeKey === tab.key, })} > {tab.label} diff --git a/web/src/components/ButtonCheckbox/index.tsx b/web/src/components/ButtonCheckbox/index.tsx index 
8c52701b..0804a1b3 100644 --- a/web/src/components/ButtonCheckbox/index.tsx +++ b/web/src/components/ButtonCheckbox/index.tsx @@ -74,9 +74,9 @@ const ButtonCheckbox: FC = ({ onClick={handleChange} > {/* Display unchecked icon when not checked */} - {icon && !checked && } + {icon && !checked && {icon}} {/* Display checked icon when checked */} - {checkedIcon && checked && } + {checkedIcon && checked && {checkedIcon}} {children} ); diff --git a/web/src/components/Chat/AudioPlayer.tsx b/web/src/components/Chat/AudioPlayer.tsx new file mode 100644 index 00000000..766c8deb --- /dev/null +++ b/web/src/components/Chat/AudioPlayer.tsx @@ -0,0 +1,152 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-03-16 15:00:07 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-27 15:23:14 + */ +import { type FC, useRef, useState, useEffect } from 'react' +import { Flex, Dropdown, type MenuProps, Slider } from 'antd' +import clsx from 'clsx' +import { useTranslation } from 'react-i18next' + +/** Available playback speed options. */ +const SPEEDS = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2] + +/** Format seconds into "MM:SS" display string. */ +const fmt = (s: number) => `${String(Math.floor(s / 60)).padStart(2, '0')}:${String(Math.floor(s % 60)).padStart(2, '0')}` + +/** + * Props for the AudioPlayer component. + * @property src - Audio file URL to play. + * @property fileName - Display name shown beside the file icon. + * @property fileSize - Human-readable file size string (e.g. "3.2 MB"). + */ +interface AudioPlayerProps { + src: string + fileName?: string + fileSize?: string +} + +/** + * AudioPlayer – A compact inline audio player with playback controls. + * + * Displays file metadata (name & size), a play/pause toggle, a seekable + * progress slider, elapsed/total time, and a dropdown menu for downloading + * the file or changing playback speed. 
+ * + * @example + * + */ +const AudioPlayer: FC = ({ src, fileName, fileSize }) => { + const { t } = useTranslation() + const audioRef = useRef(null) + const [playing, setPlaying] = useState(false) + const [current, setCurrent] = useState(0) + const [duration, setDuration] = useState(0) + const [speed, setSpeed] = useState(1) + + /* Bind native audio events to sync React state; re-binds when src changes. */ + useEffect(() => { + const audio = audioRef.current + if (!audio) return + const onTime = () => setCurrent(audio.currentTime) + const onMeta = () => setDuration(audio.duration) + const onEnd = () => setPlaying(false) + audio.addEventListener('timeupdate', onTime) + audio.addEventListener('loadedmetadata', onMeta) + audio.addEventListener('ended', onEnd) + return () => { + audio.removeEventListener('timeupdate', onTime) + audio.removeEventListener('loadedmetadata', onMeta) + audio.removeEventListener('ended', onEnd) + } + }, [src]) + + /** Toggle between play and pause. */ + const togglePlay = () => { + const audio = audioRef.current + if (!audio) return + if (playing) { audio.pause(); setPlaying(false) } + else { audio.play(); setPlaying(true) } + } + + /** Seek to a specific position (in seconds) on the audio timeline. */ + const handleSeek = (val: number) => { + if (audioRef.current) audioRef.current.currentTime = val + setCurrent(val) + } + + /** Update playback speed on both React state and the native audio element. */ + const setPlaybackSpeed = (s: number) => { + setSpeed(s) + if (audioRef.current) audioRef.current.playbackRate = s + } + + /** Open the audio source URL in a new tab to trigger download. */ + const handleDownload = () => window.open(src, '_blank') + + /** Dropdown menu items: download and playback speed sub-menu. */ + const mainMenu: MenuProps = { + items: [ + { + key: 'download', + icon:
, + label: t('common.download'), + onClick: handleDownload, + }, + { + key: 'speed', + icon:
, + label: t('perceptualDetail.playbackSpeed'), + children: SPEEDS.map(s => ({ + key: String(s), + label: {s === 1 ? 'normal' : s}, + onClick: () => setPlaybackSpeed(s), + })), + }, + ], + } + + return ( +
+