From 79ab929fb0280626a92f9a8d263b86c363d0b5bb Mon Sep 17 00:00:00 2001 From: Ke Sun <33739460+keeees@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:01:57 +0800 Subject: [PATCH] Release/v0.2.3 (#355) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(web): add PageEmpty component * feat(web): add PageTabs component * feat(web): add PageEmpty component * feat(web): add PageTabs component * feat(prompt): add history tracking for prompt releases * feat(web): add prompt menu * refactor: The PageScrollList component supports two generic parameters * feat(web): BodyWrapper compoent update PageLoading * feat(web): add Ontology menu * feat(web): memory management add scene * feat(tasks): add celery task configuration for periodic jobs - Add ignore_result=True to prevent storing results for periodic tasks - Set max_retries=0 to skip failed periodic tasks without retry attempts - Configure acks_late=False for immediate acknowledgment in beat tasks - Add time_limit and soft_time_limit to regenerate_memory_cache task (3600s/3300s) - Add time_limit and soft_time_limit to workspace_reflection_task (300s/240s) - Add time_limit and soft_time_limit to run_forgetting_cycle_task (7200s/7000s) - Improve task reliability and resource management for scheduled jobs * feat(sandbox): add Node.js code execution support to sandbox * Release/v0.2.2 (#260) * [modify] migration script * [add] migration script * fix(web): change form message * fix(web): the memoryContent field is compatible with numbers and strings * feat(web): code node hidden * fix(model): 1. create a basic model to check if the name and provider are duplicated. 2. The result shows error models because the provider created API Keys for all matching models. --------- Co-authored-by: Mark Co-authored-by: zhaoying Co-authored-by: yingzhao Co-authored-by: Timebomb2018 <18868801967@163.com> * Feature/ontology class clean (#249) * [add] Complete ontology engineering feature implementation * [add] Add ontology feature integration and validation utilities * [add] Add OWL validator and validation utilities * [fix] Add missing render_ontology_extraction_prompt function * [fix]Add dependencies, fix functionality * [add] migration script * feat(celery): add dedicated periodic tasks worker and queue (#261) * fix(web): conflict resolve * Fix/v022 bug (#263) * [fix]Fix the issue of inconsistent language in explicit and episodic memory. * [fix]Fix the issue of inconsistent language in explicit and episodic memory. * [add]Add scene_id * [fix]Based on the AI review to fix the code * Fix/develop memory reflex (#265) * 遗漏的历史映射 * 遗漏的历史映射 * 反思后台报错处理 * [add] migration script * fix: chat conversation_id add node_start * feat(web): show code node * fix(web): Restructure the CustomSelect component, repair the interface that is called multiple times when the form is updated * feat(web): RadioGroupCard support block mode * feat(web): create space add icon * feat(app and model): token consumption statistics * Add/develop memory (#264) * 遗漏的历史映射 * 遗漏的历史映射 * 遗漏的历史映射 * 遗漏的历史映射 * 遗漏的历史映射 * 遗漏的历史映射 * 遗漏的历史映射 * 遗漏的历史映射 * 遗漏的历史映射 * 新增长期记忆功能 * 新增长期记忆功能 * 新增长期记忆功能 * 知识库检索多余字段 * 长期 * feat(app and model): token consumption statistics of the cluster * memory_BUG_fix * fix(web): prompt history remove pageLoading * fix(prompt): remove hard-coded import of prompt file paths (#279) * Fix/develop memory bug (#274) * 遗漏的历史映射 * 遗漏的历史映射 * fix_timeline_memories * fix(web): update retrieve_type key * Fix/develop memory bug (#276) * 遗漏的历史映射 * 遗漏的历史映射 * fix_timeline_memories * fix_timeline_memories * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * chore(celery): disable periodic task scheduling * fix(prompt): remove hard-coded import of prompt file paths --------- Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Co-authored-by: zhaoying Co-authored-by: yingzhao Co-authored-by: Ke Sun * fix(web): remove delete confirm content * refactor(workflow): relocate template directory into workflow * feat(memory): add long-term storage task routing and batching * fix(web): PageScrollList loading update * fix(web): PageScrollList loading update * Ontology v1 bug (#291) * [changes]Add 'id' as the secondary sorting key, and 'scene_id' now returns a UUID object * [fix]Fix the "end_user" return to be sorted by update time. * [fix]Set the default values of the memory configuration model based on the spatial model. * [fix]Remove the entity extraction check combination model, read the configuration list, and add the return of scene_id * [fix]Fix the "end_user" return to be sorted by update time. * [fix] * fix(memory): add Redis session validation - Add macOS fork() safety configuration in celery_app.py to prevent initialization issues - Add null/False checks for Redis session queries in term_memory_save to handle missing sessions gracefully - Add null/False checks in memory_long_term_storage to prevent processing empty Redis results - Add null/False checks in aggregate_judgment before format_parsing to avoid errors on missing data - Initialize redis_messages variable in window_dialogue for consistency - Add debug logging when no existing session found in Redis for better troubleshooting - Add TODO comments for magic numbers (scope=6, time=5) to be extracted as constants - Improve error handling when Redis returns False or empty results instead of crashing * fix(web): PageScrollList style update * fix(workflow): fix argument passing in code execution nodes * fix(web): prompt add disabled * fix(web): space icon required * feat(app): modify the key of the token * fix(fix the key of the app's token): * fix(workflow): switch code input encoding to base64+URL encoding * [add]The main project adds multi-API Key load balancing. * [changes]Attribute security access, secure numerical conversion, unified use of local variables * fix(web): save add session update * fix(web): language editor support paste * [changes]Active status filtering logic, API Key selection strategy * memory_BUG * memory_BUG_long_term * [changes] * memory_BUG_long_term * memory_BUG_long_term * Fix/release memory bug (#306) * memory_BUG_fix * memory_BUG * memory_BUG_long_term * memory_BUG_long_term * memory_BUG_long_term * knowledge_retrieval/bug/fix * knowledge_retrieval/bug/fix * knowledge_retrieval/bug/fix * [fix]1.The "read_all_config" interface returns "scene_name";2.Memory configuration for lightweight query ontology scenarios * fix(web): replace code editor * [changes]Modify the description of the time for the recent event * [changes]Modify the code based on the AI review * feat(web): update memory config ontology api * fix(web): ui update * knowledge_retrieval/bug/fix * knowledge_retrieval/bug/fix * knowledge_retrieval/bug/fix * feat(workflow): add token usage statistics for question classifier and parameter extraction * feat(web): move prompt menu * Multiple independent transactions - single transaction * Multiple independent transactions - single transaction * Multiple independent transactions - single transaction * Multiple independent transactions - single transaction * Write Missing None (#321) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Fix/release memory bug (#324) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None * redis update * redis update * redis update * redis update --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Fix/writer memory bug (#326) * [fix]Fix the bug * [fix]Fix the bug * [fix]Correct the direction indication. * fix(web): markdown table ui update * Fix/release memory bug (#332) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None * redis update * redis update * redis update * redis update * writer_dup_bug/fix --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Fix/fact summary (#333) * [fix]Disable the contents related to fact_summary * [fix]Disable the contents related to fact_summary * [fix]Modify the code based on the AI review * Fix/release memory bug (#335) * Write Missing None * Write Missing None * Write Missing None * Apply suggestion from @sourcery-ai[bot] Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Write Missing None * redis update * redis update * redis update * redis update * writer_dup_bug/fix * writer_graph_bug/fix * writer_graph_bug/fix --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Revert "feat(web): move prompt menu" This reverts commit 9e6e8f50f8136fb8c963af34d9446dc49a237cad. * fix(web): ui update * fix(web): update text * fix(web): ui update * fix(model): change the "vl" model type of dashscope to "chat" * fix(model): change the "vl" model type of dashscope to "chat" --------- Co-authored-by: zhaoying Co-authored-by: Eternity <1533512157@qq.com> Co-authored-by: Mark Co-authored-by: yingzhao Co-authored-by: Timebomb2018 <18868801967@163.com> Co-authored-by: 乐力齐 <162269739+lanceyq@users.noreply.github.com> Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Co-authored-by: lixinyue <2569494688@qq.com> Co-authored-by: Eternity <61316157+myhMARS@users.noreply.github.com> Co-authored-by: lanceyq <1982376970@qq.com> Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- api/app/celery_app.py | 86 +- api/app/controllers/__init__.py | 2 + .../memory_reflection_controller.py | 74 +- .../controllers/memory_storage_controller.py | 5 + api/app/controllers/ontology_controller.py | 1005 ++++++++++++++ .../controllers/ontology_secondary_routes.py | 611 +++++++++ .../prompt_optimizer_controller.py | 118 +- api/app/core/agent/langchain_agent.py | 183 +-- api/app/core/config.py | 5 + .../langgraph_graph/routing/write_router.py | 238 ++++ .../agent/langgraph_graph/tools/tool.py | 3 +- .../agent/langgraph_graph/tools/write_tool.py | 72 + .../agent/langgraph_graph/write_graph.py | 112 +- .../agent/models/write_aggregate_model.py | 28 + .../prompt/write_aggregate_judgment.jinja2 | 57 + api/app/core/memory/agent/utils/redis_base.py | 186 +++ api/app/core/memory/agent/utils/redis_tool.py | 793 ++++++++--- .../core/memory/agent/utils/write_tools.py | 56 +- api/app/core/memory/models/__init__.py | 9 + api/app/core/memory/models/graph_models.py | 3 +- api/app/core/memory/models/ontology_models.py | 135 ++ .../deduplication/deduped_and_disamb.py | 65 +- .../deduplication/entity_dedup_llm.py | 23 +- .../deduplication/second_layer_dedup.py | 3 +- .../extraction_orchestrator.py | 3 +- .../knowledge_extraction/__init__.py | 1 + .../knowledge_extraction/memory_summary.py | 74 +- .../ontology_extraction.py | 482 +++++++ .../triplet_extraction.py | 12 +- api/app/core/memory/utils/alias_utils.py | 4 +- .../core/memory/utils/prompt/prompt_utils.py | 62 +- .../utils/prompt/prompts/entity_dedup.jinja2 | 6 +- .../episodic_type_classification.jinja2 | 11 + .../prompt/prompts/extract_ontology.jinja2 | 210 +++ .../prompt/prompts/extract_triplet.jinja2 | 21 +- .../prompt/prompts/memory_summary.jinja2 | 18 +- .../core/memory/utils/validation/__init__.py | 10 + .../utils/validation/ontology_validator.py | 268 ++++ .../memory/utils/validation/owl_validator.py | 585 +++++++++ .../core/models/scripts/dashscope_models.yaml | 26 +- api/app/core/rag/nlp/search.py | 40 +- api/app/core/workflow/nodes/code/config.py | 2 +- api/app/core/workflow/nodes/code/node.py | 36 +- .../nodes/parameter_extractor/node.py | 13 + .../nodes/question_classifier/node.py | 13 + api/app/core/workflow/template_loader.py | 33 +- .../templates}/customer_service/template.yml | 0 .../templates}/data_processing/template.yml | 0 .../templates}/multi_step_qa/template.yml | 0 .../templates}/simple_qa/template.yml | 0 api/app/models/__init__.py | 4 + api/app/models/memory_config_model.py | 3 + api/app/models/ontology_class.py | 40 + api/app/models/ontology_scene.py | 43 + api/app/models/prompt_optimizer_model.py | 27 +- .../repositories/memory_config_repository.py | 29 +- api/app/repositories/neo4j/add_edges.py | 6 +- api/app/repositories/neo4j/add_nodes.py | 4 +- api/app/repositories/neo4j/cypher_queries.py | 76 +- api/app/repositories/neo4j/graph_saver.py | 158 ++- .../repositories/ontology_class_repository.py | 404 ++++++ .../repositories/ontology_scene_repository.py | 439 +++++++ .../prompt_optimizer_repository.py | 217 ++- api/app/schemas/app_schema.py | 4 +- api/app/schemas/memory_agent_schema.py | 14 +- api/app/schemas/memory_storage_schema.py | 10 +- api/app/schemas/ontology_schemas.py | 461 +++++++ api/app/schemas/prompt_optimizer_schema.py | 17 + api/app/services/app_chat_service.py | 65 +- api/app/services/app_statistics_service.py | 4 +- api/app/services/conversation_service.py | 14 +- api/app/services/draft_run_service.py | 41 +- api/app/services/handoffs_service.py | 18 +- api/app/services/memory_dashboard_service.py | 10 +- api/app/services/memory_reflection_service.py | 28 +- api/app/services/memory_storage_service.py | 17 +- api/app/services/multi_agent_orchestrator.py | 20 +- api/app/services/multi_agent_service.py | 106 +- api/app/services/ontology_service.py | 1162 +++++++++++++++++ api/app/services/prompt_optimizer_service.py | 178 ++- api/app/services/shared_chat_service.py | 21 +- api/app/services/user_memory_service.py | 18 +- api/app/tasks.py | 320 ++++- api/app/utils/config_utils.py | 58 +- api/app/version_info.json | 28 + api/docker-compose.yml | 17 +- api/env.example | 5 + .../versions/550c10595967_202601301521.py | 78 ++ .../versions/9def72f79398_202601301850.py | 30 + api/pyproject.toml | 1 + sandbox/Dockerfile | 19 +- sandbox/app/__init__.py | 4 + sandbox/app/config.py | 118 +- sandbox/app/controllers/health_controller.py | 2 +- sandbox/app/controllers/sandbox_controller.py | 14 +- sandbox/app/core/runners/__init__.py | 39 + sandbox/app/core/runners/nodejs/__init__.py | 3 + sandbox/app/core/runners/nodejs/env.py | 124 ++ .../app/core/runners/nodejs/nodejs_runner.py | 138 ++ sandbox/app/core/runners/nodejs/prescript.js | 31 + sandbox/app/core/runners/python/__init__.py | 7 +- sandbox/app/core/runners/python/env.py | 70 +- sandbox/app/core/runners/python/prescript.py | 7 +- .../app/core/runners/python/python_runner.py | 8 +- sandbox/app/core/runners/python/settings.py | 62 - sandbox/app/dependencies.py | 8 +- sandbox/app/logger.py | 24 +- sandbox/app/middleware/concurrency.py | 94 +- sandbox/app/services/nodejs_service.py | 43 + sandbox/config.yaml | 10 +- .../nodejs/node_modules/.package-lock.json | 6 + sandbox/dependencies/nodejs/package-lock.json | 6 + sandbox/dependencies/nodejs/package.json | 1 + .../{ => python}/python-requirements.txt | 0 sandbox/lib/seccomp_nodejs/Cargo.lock | 7 - sandbox/lib/seccomp_nodejs/Cargo.toml | 6 - sandbox/lib/seccomp_nodejs/src/lib.rs | 0 .../Cargo.lock | 4 +- .../Cargo.toml | 11 +- .../src/lib.rs | 37 +- .../seccomp_redbear/src/nodejs_syscalls.rs | 74 ++ .../src/python_syscalls.rs} | 38 +- sandbox/main.py | 85 +- web/package.json | 10 + web/src/api/ontology.ts | 40 + web/src/api/prompt.ts | 19 +- web/src/assets/images/menu/ontology.svg | 11 + .../assets/images/menu/ontology_active.svg | 11 + web/src/assets/images/menu/prompt.svg | 15 + web/src/assets/images/menu/prompt_active.svg | 15 + web/src/assets/images/space/neo4j.png | Bin 0 -> 1424 bytes web/src/assets/images/space/rag.png | Bin 0 -> 1719 bytes web/src/components/CodeMirrorEditor/index.tsx | 150 +++ web/src/components/CustomSelect/index.tsx | 7 +- web/src/components/Empty/BodyWrapper.tsx | 8 +- web/src/components/Markdown/index.tsx | 6 +- web/src/components/PageScrollList/index.tsx | 35 +- web/src/components/RadioGroupCard/index.tsx | 39 +- web/src/components/SiderMenu/index.tsx | 44 +- web/src/i18n/en.ts | 61 +- web/src/i18n/zh.ts | 59 +- web/src/routes/index.tsx | 4 + web/src/routes/routes.json | 5 +- web/src/store/menu.json | 30 + web/src/styles/index.css | 5 + web/src/views/ApiKeyManagement/index.tsx | 5 +- .../components/AiPromptModal.tsx | 4 +- .../components/Editor/index.tsx | 11 +- .../Editor/plugin/EditablePlugin.tsx | 48 + .../components/Knowledge/Knowledge.tsx | 2 +- .../Knowledge/KnowledgeListModal.tsx | 2 +- .../components/ModelConfigModal.tsx | 8 +- web/src/views/ApplicationManagement/index.tsx | 20 +- web/src/views/ApplicationManagement/types.ts | 3 + web/src/views/Conversation/index.tsx | 1 + .../components/MemoryForm.tsx | 18 + web/src/views/MemoryManagement/index.tsx | 15 +- web/src/views/MemoryManagement/types.ts | 3 + .../components/ModelImplement/index.tsx | 1 - .../components/MultiKeyConfigModal.tsx | 6 +- .../components/OntologyClassExtractModal.tsx | 173 +++ .../components/OntologyClassModal.tsx | 96 ++ .../Ontology/components/OntologyModal.tsx | 99 ++ .../views/Ontology/components/PageHeader.tsx | 45 + web/src/views/Ontology/index.tsx | 133 ++ web/src/views/Ontology/pages/Detail.tsx | 122 ++ web/src/views/Ontology/types.ts | 79 ++ web/src/views/Prompt/History.tsx | 95 ++ web/src/views/Prompt/Prompt.tsx | 228 ++++ .../views/Prompt/components/PromptDetail.tsx | 82 ++ .../Prompt/components/PromptSaveModal.tsx | 90 ++ .../Prompt/components/PromptVariableModal.tsx | 104 ++ web/src/views/Prompt/index.tsx | 59 + web/src/views/Prompt/types.ts | 35 + .../SpaceManagement/components/SpaceModal.tsx | 200 +-- web/src/views/SpaceManagement/index.tsx | 11 +- web/src/views/SpaceManagement/types.ts | 9 +- .../Workflow/components/Editor/index.tsx | 12 +- .../components/Editor/plugin/BlurPlugin.tsx | 6 + .../Editor/plugin/InitialValuePlugin.tsx | 12 +- .../plugin/JavaScriptHighlightPlugin.tsx | 164 --- .../Editor/plugin/Python3HighlightPlugin.tsx | 159 --- .../Properties/CodeExecution/index.tsx | 15 +- .../Properties/Knowledge/Knowledge.tsx | 2 +- .../Knowledge/KnowledgeListModal.tsx | 2 +- web/src/views/Workflow/constant.ts | 52 +- .../views/Workflow/hooks/useWorkflowGraph.ts | 6 +- 187 files changed, 12252 insertions(+), 1656 deletions(-) create mode 100644 api/app/controllers/ontology_controller.py create mode 100644 api/app/controllers/ontology_secondary_routes.py create mode 100644 api/app/core/memory/agent/langgraph_graph/routing/write_router.py create mode 100644 api/app/core/memory/agent/langgraph_graph/tools/write_tool.py create mode 100644 api/app/core/memory/agent/models/write_aggregate_model.py create mode 100644 api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 create mode 100644 api/app/core/memory/agent/utils/redis_base.py create mode 100644 api/app/core/memory/models/ontology_models.py create mode 100644 api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/ontology_extraction.py create mode 100644 api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2 create mode 100644 api/app/core/memory/utils/validation/__init__.py create mode 100644 api/app/core/memory/utils/validation/ontology_validator.py create mode 100644 api/app/core/memory/utils/validation/owl_validator.py rename api/app/{templates/workflows => core/workflow/templates}/customer_service/template.yml (100%) rename api/app/{templates/workflows => core/workflow/templates}/data_processing/template.yml (100%) rename api/app/{templates/workflows => core/workflow/templates}/multi_step_qa/template.yml (100%) rename api/app/{templates/workflows => core/workflow/templates}/simple_qa/template.yml (100%) create mode 100644 api/app/models/ontology_class.py create mode 100644 api/app/models/ontology_scene.py create mode 100644 api/app/repositories/ontology_class_repository.py create mode 100644 api/app/repositories/ontology_scene_repository.py create mode 100644 api/app/schemas/ontology_schemas.py create mode 100644 api/app/services/ontology_service.py create mode 100644 api/migrations/versions/550c10595967_202601301521.py create mode 100644 api/migrations/versions/9def72f79398_202601301850.py create mode 100644 sandbox/app/__init__.py create mode 100644 sandbox/app/core/runners/nodejs/__init__.py create mode 100644 sandbox/app/core/runners/nodejs/env.py create mode 100644 sandbox/app/core/runners/nodejs/nodejs_runner.py create mode 100644 sandbox/app/core/runners/nodejs/prescript.js delete mode 100644 sandbox/app/core/runners/python/settings.py create mode 100644 sandbox/app/services/nodejs_service.py create mode 100644 sandbox/dependencies/nodejs/node_modules/.package-lock.json create mode 100644 sandbox/dependencies/nodejs/package-lock.json create mode 100644 sandbox/dependencies/nodejs/package.json rename sandbox/dependencies/{ => python}/python-requirements.txt (100%) delete mode 100644 sandbox/lib/seccomp_nodejs/Cargo.lock delete mode 100644 sandbox/lib/seccomp_nodejs/Cargo.toml delete mode 100644 sandbox/lib/seccomp_nodejs/src/lib.rs rename sandbox/lib/{seccomp_python => seccomp_redbear}/Cargo.lock (92%) rename sandbox/lib/{seccomp_python => seccomp_redbear}/Cargo.toml (51%) rename sandbox/lib/{seccomp_python => seccomp_redbear}/src/lib.rs (82%) create mode 100644 sandbox/lib/seccomp_redbear/src/nodejs_syscalls.rs rename sandbox/lib/{seccomp_python/src/syscalls.rs => seccomp_redbear/src/python_syscalls.rs} (90%) create mode 100644 web/src/api/ontology.ts create mode 100644 web/src/assets/images/menu/ontology.svg create mode 100644 web/src/assets/images/menu/ontology_active.svg create mode 100644 web/src/assets/images/menu/prompt.svg create mode 100644 web/src/assets/images/menu/prompt_active.svg create mode 100644 web/src/assets/images/space/neo4j.png create mode 100644 web/src/assets/images/space/rag.png create mode 100644 web/src/components/CodeMirrorEditor/index.tsx create mode 100644 web/src/views/ApplicationConfig/components/Editor/plugin/EditablePlugin.tsx create mode 100644 web/src/views/Ontology/components/OntologyClassExtractModal.tsx create mode 100644 web/src/views/Ontology/components/OntologyClassModal.tsx create mode 100644 web/src/views/Ontology/components/OntologyModal.tsx create mode 100644 web/src/views/Ontology/components/PageHeader.tsx create mode 100644 web/src/views/Ontology/index.tsx create mode 100644 web/src/views/Ontology/pages/Detail.tsx create mode 100644 web/src/views/Ontology/types.ts create mode 100644 web/src/views/Prompt/History.tsx create mode 100644 web/src/views/Prompt/Prompt.tsx create mode 100644 web/src/views/Prompt/components/PromptDetail.tsx create mode 100644 web/src/views/Prompt/components/PromptSaveModal.tsx create mode 100644 web/src/views/Prompt/components/PromptVariableModal.tsx create mode 100644 web/src/views/Prompt/index.tsx create mode 100644 web/src/views/Prompt/types.ts delete mode 100644 web/src/views/Workflow/components/Editor/plugin/JavaScriptHighlightPlugin.tsx delete mode 100644 web/src/views/Workflow/components/Editor/plugin/Python3HighlightPlugin.tsx diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 185d746c..002547f6 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -3,9 +3,14 @@ import platform from datetime import timedelta from urllib.parse import quote -from app.core.config import settings from celery import Celery +from app.core.config import settings + +# macOS fork() safety - must be set before any Celery initialization +if platform.system() == 'Darwin': + os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') + # 创建 Celery 应用实例 # broker: 任务队列(使用 Redis DB 0) # backend: 结果存储(使用 Redis DB 10) @@ -63,15 +68,20 @@ celery_app.conf.update( 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, + # Long-term storage tasks → memory_tasks queue (batched write strategies) + 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, - # Beat/periodic tasks → document_tasks queue (prefork worker) - 'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'}, - 'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'}, - 'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'}, - 'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'}, + # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) + 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, + 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, + 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, + 'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'}, }, ) @@ -79,40 +89,40 @@ celery_app.conf.update( celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) -memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) -workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME -forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 +# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) +# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) +# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME +# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 # 构建定时任务配置 -beat_schedule_config = { - "run-workspace-reflection": { - "task": "app.tasks.workspace_reflection_task", - "schedule": workspace_reflection_schedule, - "args": (), - }, - "regenerate-memory-cache": { - "task": "app.tasks.regenerate_memory_cache", - "schedule": memory_cache_regeneration_schedule, - "args": (), - }, - "run-forgetting-cycle": { - "task": "app.tasks.run_forgetting_cycle_task", - "schedule": forgetting_cycle_schedule, - "kwargs": { - "config_id": None, # 使用默认配置,可以通过环境变量配置 - }, - }, -} +# beat_schedule_config = { +# "run-workspace-reflection": { +# "task": "app.tasks.workspace_reflection_task", +# "schedule": workspace_reflection_schedule, +# "args": (), +# }, +# "regenerate-memory-cache": { +# "task": "app.tasks.regenerate_memory_cache", +# "schedule": memory_cache_regeneration_schedule, +# "args": (), +# }, +# "run-forgetting-cycle": { +# "task": "app.tasks.run_forgetting_cycle_task", +# "schedule": forgetting_cycle_schedule, +# "kwargs": { +# "config_id": None, # 使用默认配置,可以通过环境变量配置 +# }, +# }, +# } # 如果配置了默认工作空间ID,则添加记忆总量统计任务 -if settings.DEFAULT_WORKSPACE_ID: - beat_schedule_config["write-total-memory"] = { - "task": "app.controllers.memory_storage_controller.search_all", - "schedule": memory_increment_schedule, - "kwargs": { - "workspace_id": settings.DEFAULT_WORKSPACE_ID, - }, - } +# if settings.DEFAULT_WORKSPACE_ID: +# beat_schedule_config["write-total-memory"] = { +# "task": "app.controllers.memory_storage_controller.search_all", +# "schedule": memory_increment_schedule, +# "kwargs": { +# "workspace_id": settings.DEFAULT_WORKSPACE_ID, +# }, +# } -celery_app.conf.beat_schedule = beat_schedule_config +# celery_app.conf.beat_schedule = beat_schedule_config diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 3701f14d..765ef967 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -45,6 +45,7 @@ from . import ( home_page_controller, memory_perceptual_controller, memory_working_controller, + ontology_controller, ) # 创建管理端 API 路由器 @@ -90,5 +91,6 @@ manager_router.include_router(implicit_memory_controller.router) manager_router.include_router(memory_perceptual_controller.router) manager_router.include_router(memory_working_controller.router) manager_router.include_router(file_storage_controller.router) +manager_router.include_router(ontology_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 7941be35..8d5408f1 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -51,7 +51,6 @@ async def save_reflection_config( status_code=status.HTTP_400_BAD_REQUEST, detail="缺少必需参数: config_id" ) - api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") memory_config = MemoryConfigRepository.update_reflection_config( @@ -102,7 +101,7 @@ async def start_workspace_reflection( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """Activate the reflection function for all matching applications in the workspace""" + """启动工作空间中所有匹配应用的反思功能""" workspace_id = current_user.current_workspace_id reflection_service = MemoryReflectionService(db) @@ -111,42 +110,55 @@ async def start_workspace_reflection( service = WorkspaceAppService(db) result = service.get_workspace_apps_detailed(workspace_id) - reflection_results = [] - for data in result['apps_detailed_info']: - if data['memory_configs'] == []: + # 跳过没有配置的应用 + if not data['memory_configs']: + api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过") continue - + releases = data['releases'] memory_configs = data['memory_configs'] end_users = data['end_users'] - - for base, config, user in zip(releases, memory_configs, end_users): - # 安全地转换为整数,处理空字符串和None的情况 - print(base['config']) - try: - base_config = int(base['config']) if base['config'] else 0 - config_id = int(config['config_id']) if config['config_id'] else 0 - except (ValueError, TypeError): - api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}") + + # 为每个配置和用户组合执行反思 + for config in memory_configs: + config_id_str = str(config['config_id']) + + # 找到匹配此配置的所有release + matching_releases = [r for r in releases if str(r['config']) == config_id_str] + + if not matching_releases: + api_logger.debug(f"配置 {config_id_str} 没有匹配的release") continue - - if base_config == config_id and base['app_id'] == user['app_id']: - # 调用反思服务 - api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") - - reflection_result = await reflection_service.start_text_reflection( - config_data=config, - end_user_id=user['id'] - ) - - reflection_results.append({ - "app_id": base['app_id'], - "config_id": config['config_id'], - "end_user_id": user['id'], - "reflection_result": reflection_result - }) + + # 为每个用户执行反思 + for user in end_users: + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}") + + try: + reflection_result = await reflection_service.start_text_reflection( + config_data=config, + end_user_id=user['id'] + ) + + reflection_results.append({ + "app_id": data['id'], + "config_id": config_id_str, + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + except Exception as e: + api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}") + reflection_results.append({ + "app_id": data['id'], + "config_id": config_id_str, + "end_user_id": user['id'], + "reflection_result": { + "status": "错误", + "message": f"反思失败: {str(e)}" + } + }) return success(data=reflection_results, msg="反思配置成功") diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index ae372d3b..0b627775 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -195,6 +195,11 @@ def update_config( api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + # 校验至少有一个字段需要更新 + if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py new file mode 100644 index 00000000..4e244e35 --- /dev/null +++ b/api/app/controllers/ontology_controller.py @@ -0,0 +1,1005 @@ +"""本体提取API控制器 + +本模块提供本体提取系统的RESTful API端点。 + +Endpoints: + POST /api/memory/ontology/extract - 提取本体类 + POST /api/memory/ontology/export - 导出OWL文件 + POST /api/memory/ontology/scene - 创建本体场景 + PUT /api/memory/ontology/scene/{scene_id} - 更新本体场景 + DELETE /api/memory/ontology/scene/{scene_id} - 删除本体场景 + GET /api/memory/ontology/scene/{scene_id} - 获取单个场景 + GET /api/memory/ontology/scenes - 获取场景列表 + POST /api/memory/ontology/class - 创建本体类型 + PUT /api/memory/ontology/class/{class_id} - 更新本体类型 + DELETE /api/memory/ontology/class/{class_id} - 删除本体类型 + GET /api/memory/ontology/class/{class_id} - 获取单个类型 + GET /api/memory/ontology/classes - 获取类型列表 +""" + +import logging +import tempfile +from typing import Dict, Optional + +from fastapi import APIRouter, Depends, HTTPException, Header +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import fail, success +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User +from app.services.memory_base_service import Translation_English +from app.core.memory.models.ontology_models import OntologyClass +from typing import List +from app.schemas.ontology_schemas import ( + ExportRequest, + ExportResponse, + ExtractionRequest, + ExtractionResponse, + SceneCreateRequest, + SceneUpdateRequest, + SceneResponse, + SceneListResponse, + ClassCreateRequest, + ClassUpdateRequest, + ClassResponse, + ClassListResponse, +) +from app.schemas.response_schema import ApiResponse +from app.services.ontology_service import OntologyService +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.memory.utils.validation.owl_validator import OWLValidator +from app.services.model_service import ModelConfigService +from app.repositories.ontology_scene_repository import OntologySceneRepository + + +api_logger = get_api_logger() +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/memory/ontology", + tags=["Ontology"], +) + + +async def translate_ontology_classes( + classes: List[OntologyClass], + model_id: str +) -> List[OntologyClass]: + """翻译本体类列表 + + 将本体类的中文字段翻译为英文,包括: + - name_chinese: 中文名称 + - description: 描述 + - examples: 示例列表 + + Args: + classes: 本体类列表 + model_id: LLM模型ID,用于翻译 + + Returns: + List[OntologyClass]: 翻译后的本体类列表 + """ + translated_classes = [] + + for ontology_class in classes: + # 创建类的副本,避免修改原对象 + translated_class = ontology_class.model_copy(deep=True) + + # 翻译 name_chinese 字段 + if translated_class.name_chinese: + try: + translated_class.name_chinese = await Translation_English( + model_id, + translated_class.name_chinese + ) + except Exception as e: + logger.warning(f"Failed to translate name_chinese: {e}") + # 保留原文 + + # 翻译 description 字段 + if translated_class.description: + try: + translated_class.description = await Translation_English( + model_id, + translated_class.description + ) + except Exception as e: + logger.warning(f"Failed to translate description: {e}") + # 保留原文 + + # 翻译 examples 列表 + if translated_class.examples: + translated_examples = [] + for example in translated_class.examples: + try: + translated_example = await Translation_English( + model_id, + example + ) + translated_examples.append(translated_example) + except Exception as e: + logger.warning(f"Failed to translate example: {e}") + translated_examples.append(example) # 保留原文 + translated_class.examples = translated_examples + + translated_classes.append(translated_class) + + return translated_classes + + +def _get_ontology_service( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + llm_id: str = None +) -> OntologyService: + """获取OntologyService实例的依赖注入函数 + + 指定的llm_id获取LLM配置,创建OpenAIClient和OntologyService实例。 + + Args: + db: 数据库会话 + current_user: 当前用户 + llm_id: 可选的LLM模型ID,如果提供则使用指定模型,否则使用工作空间默认模型 + + Returns: + OntologyService: 本体提取服务实例 + + Raises: + HTTPException: 如果无法获取LLM配置 + """ + try: + import uuid + + # 必须提供llm_id + if not llm_id: + logger.error(f"llm_id is required but not provided - user: {current_user.id}") + raise HTTPException( + status_code=400, + detail="必须提供llm_id参数" + ) + + logger.info(f"Using specified LLM model: {llm_id}") + + # 验证llm_id格式 + try: + model_id = uuid.UUID(llm_id) + except ValueError: + logger.error(f"Invalid llm_id format: {llm_id}") + raise HTTPException( + status_code=400, + detail="无效的LLM模型ID格式" + ) + + # 获取指定的模型配置 + try: + model_config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) + except Exception as e: + logger.error(f"Model {llm_id} not found: {str(e)}") + raise HTTPException( + status_code=400, + detail=f"找不到指定的LLM模型: {llm_id}" + ) + + # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) + from app.repositories.model_repository import ModelApiKeyRepository + api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + if not api_keys: + logger.error(f"Model {llm_id} has no active API key") + raise HTTPException( + status_code=400, + detail="指定的LLM模型没有可用的API密钥" + ) + api_key_config = api_keys[0] + + is_composite = getattr(model_config, 'is_composite', False) + logger.info( + f"Using specified model - user: {current_user.id}, " + f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " + f"is_composite: {is_composite}, api_key_id: {api_key_config.id}" + ) + + # 创建模型配置对象 + from app.core.models.base import RedBearModelConfig + + # 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider + actual_provider = api_key_config.provider if is_composite else ( + getattr(model_config, 'provider', None) or "openai" + ) + + llm_model_config = RedBearModelConfig( + model_name=api_key_config.model_name, + provider=actual_provider, + api_key=api_key_config.api_key, + base_url=api_key_config.api_base, + max_retries=3, + timeout=60.0 + ) + + # 创建OpenAI客户端 + llm_client = OpenAIClient(model_config=llm_model_config) + + # 创建OntologyService + service = OntologyService(llm_client=llm_client, db=db) + + logger.debug( + f"OntologyService created successfully - " + f"user: {current_user.id}, model: {api_key_config.model_name}" + ) + + return service + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to create OntologyService: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"创建本体提取服务失败: {str(e)}" + ) + + +@router.post("/extract", response_model=ApiResponse) +async def extract_ontology( + request: ExtractionRequest, + language_type: str = Header(default="zh", alias="X-Language-Type"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """提取本体类 + + 从场景描述中提取符合OWL规范的本体类。 + 提取结果仅返回给前端,不会自动保存到数据库。 + 前端可以从返回结果中选择需要的类型,然后调用 /class 接口创建类型。 + 支持中英文切换,通过 X-Language-Type Header 指定语言。 + + Args: + request: 提取请求,包含scenario、domain、llm_id和scene_id + language_type: 语言类型,'zh'(中文)或 'en'(英文),默认 'zh' + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含提取结果的响应 + + Response format: + { + "code": 200, + "msg": "本体提取成功", + "data": { + "classes": [ + { + "id": "147d9db50b524a9e909e01a753d3acdd", + "name": "Patient", + "name_chinese": "患者", + "description": "在医疗机构中接受诊疗、护理或健康管理的个体", + "examples": ["糖尿病患者", "术后康复患者", "门诊初诊患者"], + "parent_class": null, + "entity_type": "Person", + "domain": "Healthcare" + }, + ... + ], + "domain": "Healthcare", + "extracted_count": 7 + } + } + """ + api_logger.info( + f"Ontology extraction requested by user {current_user.id}, " + f"scenario_length={len(request.scenario)}, " + f"domain={request.domain}, " + f"llm_id={request.llm_id}, " + f"scene_id={request.scene_id}, " + f"language_type={language_type}" + ) + + try: + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建OntologyService实例,传入llm_id + service = _get_ontology_service( + db=db, + current_user=current_user, + llm_id=request.llm_id + ) + + # 调用服务层执行提取,传入scene_id和workspace_id + result = await service.extract_ontology( + scenario=request.scenario, + domain=request.domain, + scene_id=request.scene_id, + workspace_id=workspace_id + ) + + # ===== 新增:翻译逻辑 ===== + # 如果需要英文,则翻译数据 + if language_type != 'zh': + api_logger.info(f"Translating extraction result to English") + + # 翻译 classes 列表 + result.classes = await translate_ontology_classes( + result.classes, + request.llm_id + ) + + # 翻译 domain 字段 + if result.domain: + try: + result.domain = await Translation_English( + request.llm_id, + result.domain + ) + except Exception as e: + logger.warning(f"Failed to translate domain: {e}") + # 保留原文 + # ===== 翻译逻辑结束 ===== + + # 构建响应 + response = ExtractionResponse( + classes=result.classes, + domain=result.domain, + extracted_count=len(result.classes) + ) + + api_logger.info( + f"Ontology extraction completed, extracted {len(result.classes)} classes, " + f"saved to scene {request.scene_id}, language={language_type}" + ) + + return success(data=response.model_dump(), msg="本体提取成功") + + except ValueError as e: + # 验证错误 (400) + api_logger.warning(f"Validation error in extraction: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + # 运行时错误 (500) + api_logger.error(f"Runtime error in extraction: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e)) + + except Exception as e: + # 未知错误 (500) + api_logger.error(f"Unexpected error in extraction: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e)) + + +@router.post("/export", response_model=ApiResponse) +async def export_owl( + request: ExportRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """导出OWL文件 + + 将提取的本体类导出为OWL文件,支持多种格式。 + 导出操作不需要LLM,只使用OWL验证器和Owlready2库。 + + Args: + request: 导出请求,包含classes、format和include_metadata + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含OWL文件内容的响应 + + Supported formats: + - rdfxml: 标准OWL RDF/XML格式(完整) + - turtle: Turtle格式(可读性好) + - ntriples: N-Triples格式(简单) + - json: JSON格式(简化,只包含类信息) + + Response format: + { + "code": 200, + "msg": "OWL文件导出成功", + "data": { + "owl_content": "...", + "format": "rdfxml", + "classes_count": 7 + } + } + """ + api_logger.info( + f"OWL export requested by user {current_user.id}, " + f"classes_count={len(request.classes)}, " + f"format={request.format}, " + f"include_metadata={request.include_metadata}" + ) + + try: + # 验证格式 + valid_formats = ["rdfxml", "turtle", "ntriples", "json"] + if request.format not in valid_formats: + api_logger.warning(f"Invalid export format: {request.format}") + return fail( + BizCode.BAD_REQUEST, + "不支持的导出格式", + f"format必须是以下之一: {', '.join(valid_formats)}" + ) + + # JSON格式直接导出,不需要OWL验证 + if request.format == "json": + owl_validator = OWLValidator() + owl_content = owl_validator.export_to_owl( + world=None, + format="json", + classes=request.classes + ) + + response = ExportResponse( + owl_content=owl_content, + format=request.format, + classes_count=len(request.classes) + ) + + api_logger.info( + f"JSON export completed, content_length={len(owl_content)}" + ) + + return success(data=response.model_dump(), msg="OWL文件导出成功") + + # 创建临时文件路径 + with tempfile.NamedTemporaryFile( + mode='w', + suffix='.owl', + delete=False + ) as tmp_file: + output_path = tmp_file.name + + # 导出操作不需要LLM,直接使用OWL验证器 + owl_validator = OWLValidator() + + # 验证本体类 + logger.debug("Validating ontology classes") + is_valid, errors, world = owl_validator.validate_ontology_classes( + classes=request.classes, + ) + + if not is_valid: + logger.warning( + f"OWL validation found {len(errors)} issues during export: {errors}" + ) + # 继续导出,但记录警告 + + if not world: + error_msg = "Failed to create OWL world for export" + logger.error(error_msg) + return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg) + + # 导出OWL文件 + logger.info(f"Exporting to {request.format} format") + owl_content = owl_validator.export_to_owl( + world=world, + output_path=output_path, + format=request.format, + classes=request.classes + ) + + # 构建响应 + response = ExportResponse( + owl_content=owl_content, + format=request.format, + classes_count=len(request.classes) + ) + + api_logger.info( + f"OWL export completed, format={request.format}, " + f"content_length={len(owl_content)}" + ) + + return success(data=response.model_dump(), msg="OWL文件导出成功") + + except ValueError as e: + # 验证错误 (400) + api_logger.warning(f"Validation error in export: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + # 运行时错误 (500) + api_logger.error(f"Runtime error in export: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e)) + + except Exception as e: + # 未知错误 (500) + api_logger.error(f"Unexpected error in export: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e)) + + +# ==================== 本体场景管理接口 ==================== + +@router.post("/scene", response_model=ApiResponse) +async def create_scene( + request: SceneCreateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建本体场景 + + 在当前工作空间下创建新的本体场景。 + + Args: + request: 场景创建请求 + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含创建的场景信息 + """ + api_logger.info( + f"Scene creation requested by user {current_user.id}, " + f"name={request.scene_name}" + ) + + try: + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建OntologyService实例(不需要LLM) + from app.core.memory.llm_tools.openai_client import OpenAIClient + from app.core.models.base import RedBearModelConfig + + # 创建一个空的LLM配置(场景管理不需要LLM) + dummy_config = RedBearModelConfig( + model_name="dummy", + provider="openai", + api_key="dummy", + base_url="https://api.openai.com/v1" + ) + llm_client = OpenAIClient(model_config=dummy_config) + service = OntologyService(llm_client=llm_client, db=db) + + # 调用服务层创建场景 + scene = service.create_scene( + scene_name=request.scene_name, + scene_description=request.scene_description, + workspace_id=workspace_id + ) + + # 构建响应 + # 动态计算 type_num + type_num = len(scene.classes) if scene.classes else 0 + + response = SceneResponse( + scene_id=scene.scene_id, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + type_num=type_num, + workspace_id=scene.workspace_id, + created_at=scene.created_at, + updated_at=scene.updated_at, + classes_count=type_num + ) + + api_logger.info(f"Scene created successfully: {scene.scene_id}") + + return success(data=response.model_dump(), msg="场景创建成功") + + except ValueError as e: + api_logger.warning(f"Validation error in scene creation: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e)) + + +@router.put("/scene/{scene_id}", response_model=ApiResponse) +async def update_scene( + scene_id: str, + request: SceneUpdateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新本体场景 + + 更新指定场景的信息,只能更新当前工作空间下的场景。 + + Args: + scene_id: 场景ID + request: 场景更新请求 + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含更新后的场景信息 + """ + api_logger.info( + f"Scene update requested by user {current_user.id}, " + f"scene_id={scene_id}" + ) + + try: + from uuid import UUID + + # 验证UUID格式 + try: + scene_uuid = UUID(scene_id) + except ValueError: + api_logger.warning(f"Invalid scene_id format: {scene_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建OntologyService实例 + from app.core.memory.llm_tools.openai_client import OpenAIClient + from app.core.models.base import RedBearModelConfig + + dummy_config = RedBearModelConfig( + model_name="dummy", + provider="openai", + api_key="dummy", + base_url="https://api.openai.com/v1" + ) + llm_client = OpenAIClient(model_config=dummy_config) + service = OntologyService(llm_client=llm_client, db=db) + + # 调用服务层更新场景 + scene = service.update_scene( + scene_id=scene_uuid, + scene_name=request.scene_name, + scene_description=request.scene_description, + workspace_id=workspace_id + ) + + # 构建响应 + # 动态计算 type_num + type_num = len(scene.classes) if scene.classes else 0 + + response = SceneResponse( + scene_id=scene.scene_id, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + type_num=type_num, + workspace_id=scene.workspace_id, + created_at=scene.created_at, + updated_at=scene.updated_at, + classes_count=type_num + ) + + api_logger.info(f"Scene updated successfully: {scene_id}") + + return success(data=response.model_dump(), msg="场景更新成功") + + except ValueError as e: + api_logger.warning(f"Validation error in scene update: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in scene update: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景更新失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in scene update: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景更新失败", str(e)) + + +@router.delete("/scene/{scene_id}", response_model=ApiResponse) +async def delete_scene( + scene_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除本体场景 + + 删除指定场景及其所有关联类型,只能删除当前工作空间下的场景。 + + Args: + scene_id: 场景ID + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 删除结果 + """ + api_logger.info( + f"Scene deletion requested by user {current_user.id}, " + f"scene_id={scene_id}" + ) + + try: + from uuid import UUID + + # 验证UUID格式 + try: + scene_uuid = UUID(scene_id) + except ValueError: + api_logger.warning(f"Invalid scene_id format: {scene_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建OntologyService实例 + from app.core.memory.llm_tools.openai_client import OpenAIClient + from app.core.models.base import RedBearModelConfig + + dummy_config = RedBearModelConfig( + model_name="dummy", + provider="openai", + api_key="dummy", + base_url="https://api.openai.com/v1" + ) + llm_client = OpenAIClient(model_config=dummy_config) + service = OntologyService(llm_client=llm_client, db=db) + + # 调用服务层删除场景 + success_flag = service.delete_scene( + scene_id=scene_uuid, + workspace_id=workspace_id + ) + + api_logger.info(f"Scene deleted successfully: {scene_id}") + + return success(data={"deleted": success_flag}, msg="场景删除成功") + + except ValueError as e: + api_logger.warning(f"Validation error in scene deletion: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in scene deletion: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in scene deletion: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e)) + + +@router.get("/scenes/simple", response_model=ApiResponse) +async def get_scenes_simple( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取场景简单列表(轻量级,用于下拉选择) + + 仅返回 scene_id 和 scene_name,不加载关联数据,响应速度快。 + 适用于前端下拉选择场景的场景。 + + Args: + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含场景简单列表 + + Examples: + GET /scenes/simple + 返回: {"data": [{"scene_id": "xxx", "scene_name": "场景1"}, ...]} + """ + api_logger.info(f"Simple scene list requested by user {current_user.id}") + + try: + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + repo = OntologySceneRepository(db) + scenes = repo.get_simple_list(workspace_id) + + api_logger.info(f"Simple scene list retrieved: {len(scenes)} scenes") + return success(data=scenes, msg="查询成功") + + except Exception as e: + api_logger.error(f"Failed to get simple scene list: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + +@router.get("/scenes", response_model=ApiResponse) +async def get_scenes( + workspace_id: Optional[str] = None, + scene_name: Optional[str] = None, + page: Optional[int] = None, + pagesize: Optional[int] = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取场景列表(支持模糊搜索和全量查询,全量查询支持分页) + + 根据是否提供 scene_name 参数,执行不同的查询: + - 提供 scene_name:进行模糊搜索,返回匹配的场景列表(支持分页) + - 不提供 scene_name:返回工作空间下的所有场景(支持分页) + + 支持中文和英文的模糊匹配,不区分大小写。 + + Args: + workspace_id: 工作空间ID(可选,默认当前用户工作空间) + scene_name: 场景名称关键词(可选,支持模糊匹配) + page: 页码(可选,从1开始) + pagesize: 每页数量(可选) + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含场景列表和分页信息 + + Examples: + - 模糊搜索(不分页):GET /scenes?workspace_id=xxx&scene_name=医疗 + 输入 "医疗" 可以匹配到 "医疗场景"、"智慧医疗"、"医疗管理系统" 等 + - 模糊搜索(分页):GET /scenes?workspace_id=xxx&scene_name=医疗&page=1&pagesize=10 + 返回匹配 "医疗" 的第1页,每页10条数据 + - 全量查询(不分页):GET /scenes?workspace_id=xxx + 返回工作空间下的所有场景 + - 全量查询(分页):GET /scenes?workspace_id=xxx&page=1&pagesize=10 + 返回第1页,每页10条数据 + + Notes: + - 分页参数 page 和 pagesize 必须同时提供 + - page 从1开始,pagesize 必须大于0 + - 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}} + - 不分页时,page 字段为 null + """ + from app.controllers.ontology_secondary_routes import scenes_handler + return await scenes_handler(workspace_id, scene_name, page, pagesize, db, current_user) + + +# ==================== 本体类型管理接口 ==================== + +@router.post("/class", response_model=ApiResponse) +async def create_class( + request: ClassCreateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建本体类型 + + 在指定场景下创建新的本体类型。 + + Args: + request: 类型创建请求 + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含创建的类型信息 + """ + from app.controllers.ontology_secondary_routes import create_class_handler + return await create_class_handler(request, db, current_user) + + +@router.put("/class/{class_id}", response_model=ApiResponse) +async def update_class( + class_id: str, + request: ClassUpdateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新本体类型 + + 更新指定类型的信息,只能更新当前工作空间下场景的类型。 + + Args: + class_id: 类型ID + request: 类型更新请求 + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含更新后的类型信息 + """ + from app.controllers.ontology_secondary_routes import update_class_handler + return await update_class_handler(class_id, request, db, current_user) + + +@router.delete("/class/{class_id}", response_model=ApiResponse) +async def delete_class( + class_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除本体类型 + + 删除指定类型,只能删除当前工作空间下场景的类型。 + + Args: + class_id: 类型ID + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 删除结果 + """ + from app.controllers.ontology_secondary_routes import delete_class_handler + return await delete_class_handler(class_id, db, current_user) + + +@router.get("/classes", response_model=ApiResponse) +async def get_classes( + scene_id: str, + class_name: Optional[str] = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取类型列表(支持模糊搜索和全量查询) + + 根据是否提供 class_name 参数,执行不同的查询: + - 提供 class_name:进行模糊搜索,返回匹配的类型列表 + - 不提供 class_name:返回场景下的所有类型 + + 支持中文和英文的模糊匹配,不区分大小写。 + 返回结果包含场景的基本信息(scene_name 和 scene_description)。 + + Args: + scene_id: 场景ID(必填) + class_name: 类型名称关键词(可选,支持模糊匹配) + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含类型列表和场景信息 + + Examples: + - 模糊搜索:GET /classes?scene_id=xxx&class_name=患者 + 输入 "患者" 可以匹配到 "患者"、"患者信息"、"门诊患者" 等 + - 全量查询:GET /classes?scene_id=xxx + 返回场景下的所有类型 + + Response Format: + { + "total": 3, + "scene_id": "xxx", + "scene_name": "医疗场景", + "scene_description": "用于医疗领域的本体建模", + "items": [...] + } + """ + from app.controllers.ontology_secondary_routes import classes_handler + return await classes_handler(scene_id, class_name, db, current_user) + + +@router.get("/class/{class_id}", response_model=ApiResponse) +async def get_class( + class_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取单个本体类型 + + 根据类型ID获取类型的详细信息,只能查询当前工作空间下场景的类型。 + + Args: + class_id: 类型ID + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含类型详细信息 + + Response Format: + { + "code": 0, + "msg": "查询成功", + "data": { + "class_id": "xxx", + "class_name": "患者", + "class_description": "在医疗机构中接受诊疗的个体", + "scene_id": "xxx", + "created_at": "2026-01-29T10:00:00", + "updated_at": "2026-01-29T10:00:00" + } + } + """ + from app.controllers.ontology_secondary_routes import get_class_handler + return await get_class_handler(class_id, db, current_user) diff --git a/api/app/controllers/ontology_secondary_routes.py b/api/app/controllers/ontology_secondary_routes.py new file mode 100644 index 00000000..99017eea --- /dev/null +++ b/api/app/controllers/ontology_secondary_routes.py @@ -0,0 +1,611 @@ +# -*- coding: utf-8 -*- +"""本体场景和类型路由(续) + +由于主Controller文件较大,将剩余路由放在此文件中。 +""" + +from uuid import UUID +from typing import Optional + +from fastapi import Depends +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import fail, success +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User +from app.schemas.ontology_schemas import ( + SceneResponse, + SceneListResponse, + PaginationInfo, + ClassCreateRequest, + ClassUpdateRequest, + ClassResponse, + ClassListResponse, + ClassBatchCreateResponse, +) +from app.schemas.response_schema import ApiResponse +from app.services.ontology_service import OntologyService +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.models.base import RedBearModelConfig + + +api_logger = get_api_logger() + + +def _get_dummy_ontology_service(db: Session) -> OntologyService: + """获取OntologyService实例(不需要LLM) + + 场景和类型管理不需要LLM,创建一个dummy配置。 + """ + dummy_config = RedBearModelConfig( + model_name="dummy", + provider="openai", + api_key="dummy", + base_url="https://api.openai.com/v1" + ) + llm_client = OpenAIClient(model_config=dummy_config) + return OntologyService(llm_client=llm_client, db=db) + + +# 这些函数将被导入到主Controller中 + +async def scenes_handler( + workspace_id: Optional[str] = None, + scene_name: Optional[str] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取场景列表(支持模糊搜索和全量查询,全量查询支持分页) + + 当提供 scene_name 参数时,进行模糊搜索(不分页); + 当不提供 scene_name 参数时,返回所有场景(支持分页)。 + + Args: + workspace_id: 工作空间ID(可选,默认当前用户工作空间) + scene_name: 场景名称关键词(可选,支持模糊匹配) + page: 页码(可选,从1开始,仅在全量查询时有效) + page_size: 每页数量(可选,仅在全量查询时有效) + db: 数据库会话 + current_user: 当前用户 + """ + operation = "search" if scene_name else "list" + api_logger.info( + f"Scene {operation} requested by user {current_user.id}, " + f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}" + ) + + try: + # 确定工作空间ID + if workspace_id: + try: + ws_uuid = UUID(workspace_id) + except ValueError: + api_logger.warning(f"Invalid workspace_id format: {workspace_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式") + else: + ws_uuid = current_user.current_workspace_id + if not ws_uuid: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 根据是否提供 scene_name 决定查询方式 + if scene_name and scene_name.strip(): + # 验证分页参数(模糊搜索也支持分页) + if page is not None and page < 1: + api_logger.warning(f"Invalid page number: {page}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") + + if page_size is not None and page_size < 1: + api_logger.warning(f"Invalid page_size: {page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") + + # 如果只提供了page或page_size中的一个,返回错误 + if (page is not None and page_size is None) or (page is None and page_size is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") + + # 模糊搜索场景(支持分页) + scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid) + total = len(scenes) + + # 如果提供了分页参数,进行分页处理 + if page is not None and page_size is not None: + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + scenes = scenes[start_idx:end_idx] + + # 构建响应 + items = [] + for scene in scenes: + # 获取前3个class_name作为entity_type + entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None + # 动态计算 type_num + type_num = len(scene.classes) if scene.classes else 0 + + items.append(SceneResponse( + scene_id=scene.scene_id, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + type_num=type_num, + entity_type=entity_type, + workspace_id=scene.workspace_id, + created_at=scene.created_at, + updated_at=scene.updated_at, + classes_count=type_num + )) + + # 构建响应(包含分页信息) + if page is not None and page_size is not None: + # 计算是否有下一页 + hasnext = (page * page_size) < total + + pagination_info = PaginationInfo( + page=page, + pagesize=page_size, + total=total, + hasnext=hasnext + ) + response = SceneListResponse(items=items, page=pagination_info) + else: + response = SceneListResponse(items=items) + + api_logger.info( + f"Scene search completed: found {len(items)} scenes matching '{scene_name}' " + f"in workspace {ws_uuid}, total={total}" + ) + else: + # 获取所有场景(支持分页) + # 验证分页参数 + if page is not None and page < 1: + api_logger.warning(f"Invalid page number: {page}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") + + if page_size is not None and page_size < 1: + api_logger.warning(f"Invalid page_size: {page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") + + # 如果只提供了page或page_size中的一个,返回错误 + if (page is not None and page_size is None) or (page is None and page_size is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") + + scenes, total = service.list_scenes(ws_uuid, page, page_size) + + # 构建响应 + items = [] + for scene in scenes: + # 获取前3个class_name作为entity_type + entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None + # 动态计算 type_num + type_num = len(scene.classes) if scene.classes else 0 + + items.append(SceneResponse( + scene_id=scene.scene_id, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + type_num=type_num, + entity_type=entity_type, + workspace_id=scene.workspace_id, + created_at=scene.created_at, + updated_at=scene.updated_at, + classes_count=type_num + )) + + # 构建响应(包含分页信息) + if page is not None and page_size is not None: + # 计算是否有下一页 + hasnext = (page * page_size) < total + + pagination_info = PaginationInfo( + page=page, + pagesize=page_size, + total=total, + hasnext=hasnext + ) + response = SceneListResponse(items=items, page=pagination_info) + else: + response = SceneListResponse(items=items) + + api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}") + + return success(data=response.model_dump(mode='json'), msg="查询成功") + + except ValueError as e: + api_logger.warning(f"Validation error in scene {operation}: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + +# ==================== 本体类型管理接口 ==================== + +async def create_class_handler( + request: ClassCreateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建本体类型(统一使用列表形式,支持单个或批量)""" + + # 根据列表长度判断是单个还是批量 + count = len(request.classes) + mode = "single" if count == 1 else "batch" + + api_logger.info( + f"Class creation ({mode}) requested by user {current_user.id}, " + f"scene_id={request.scene_id}, count={count}" + ) + + try: + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 准备类型数据 + classes_data = [ + { + "class_name": item.class_name, + "class_description": item.class_description + } + for item in request.classes + ] + + if count == 1: + # 单个创建 + class_data = classes_data[0] + ontology_class = service.create_class( + scene_id=request.scene_id, + class_name=class_data["class_name"], + class_description=class_data["class_description"], + workspace_id=workspace_id + ) + + # 构建单个响应 + response = ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + ) + + api_logger.info(f"Class created successfully: {ontology_class.class_id}") + + return success(data=response.model_dump(mode='json'), msg="类型创建成功") + + else: + # 批量创建 + created_classes, errors = service.create_classes_batch( + scene_id=request.scene_id, + classes=classes_data, + workspace_id=workspace_id + ) + + # 构建批量响应 + items = [] + for ontology_class in created_classes: + items.append(ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + )) + + response = ClassBatchCreateResponse( + total=len(classes_data), + success_count=len(created_classes), + failed_count=len(errors), + items=items, + errors=errors if errors else None + ) + + api_logger.info( + f"Batch class creation completed: " + f"success={len(created_classes)}, failed={len(errors)}" + ) + + return success(data=response.model_dump(mode='json'), msg="批量创建完成") + + except ValueError as e: + api_logger.warning(f"Validation error in class creation: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) + + +async def update_class_handler( + class_id: str, + request: ClassUpdateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新本体类型""" + api_logger.info( + f"Class update requested by user {current_user.id}, " + f"class_id={class_id}" + ) + + try: + # 验证UUID格式 + try: + class_uuid = UUID(class_id) + except ValueError: + api_logger.warning(f"Invalid class_id format: {class_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 更新类型 + ontology_class = service.update_class( + class_id=class_uuid, + class_name=request.class_name, + class_description=request.class_description, + workspace_id=workspace_id + ) + + # 构建响应 + response = ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + ) + + api_logger.info(f"Class updated successfully: {class_id}") + + return success(data=response.model_dump(mode='json'), msg="类型更新成功") + + except ValueError as e: + api_logger.warning(f"Validation error in class update: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e)) + + +async def delete_class_handler( + class_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除本体类型""" + api_logger.info( + f"Class deletion requested by user {current_user.id}, " + f"class_id={class_id}" + ) + + try: + # 验证UUID格式 + try: + class_uuid = UUID(class_id) + except ValueError: + api_logger.warning(f"Invalid class_id format: {class_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 删除类型 + success_flag = service.delete_class( + class_id=class_uuid, + workspace_id=workspace_id + ) + + api_logger.info(f"Class deleted successfully: {class_id}") + + return success(data={"deleted": success_flag}, msg="类型删除成功") + + except ValueError as e: + api_logger.warning(f"Validation error in class deletion: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e)) + + +async def get_class_handler( + class_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取单个本体类型""" + api_logger.info( + f"Get class requested by user {current_user.id}, " + f"class_id={class_id}" + ) + + try: + # 验证UUID格式 + try: + class_uuid = UUID(class_id) + except ValueError: + api_logger.warning(f"Invalid class_id format: {class_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 获取类型(会抛出ValueError如果不存在) + ontology_class = service.get_class_by_id(class_uuid, workspace_id) + + # 构建响应 + response = ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + ) + + api_logger.info(f"Class retrieved successfully: {class_id}") + + return success(data=response.model_dump(mode='json'), msg="查询成功") + + except ValueError as e: + # 类型不存在或无权限访问 + api_logger.warning(f"Validation error in get class: {str(e)}") + return fail(BizCode.NOT_FOUND, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + +async def classes_handler( + scene_id: str, + class_name: Optional[str] = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取类型列表(支持模糊搜索和全量查询) + + 当提供 class_name 参数时,进行模糊搜索; + 当不提供 class_name 参数时,返回场景下的所有类型。 + + Args: + scene_id: 场景ID(必填) + class_name: 类型名称关键词(可选,支持模糊匹配) + db: 数据库会话 + current_user: 当前用户 + """ + operation = "search" if class_name else "list" + api_logger.info( + f"Class {operation} requested by user {current_user.id}, " + f"keyword={class_name}, scene_id={scene_id}" + ) + + try: + # 验证UUID格式 + try: + scene_uuid = UUID(scene_id) + except ValueError: + api_logger.warning(f"Invalid scene_id format: {scene_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 获取场景信息 + scene = service.get_scene_by_id(scene_uuid, workspace_id) + if not scene: + api_logger.warning(f"Scene not found: {scene_id}") + return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景") + + # 根据是否提供 class_name 决定查询方式 + if class_name and class_name.strip(): + # 模糊搜索类型 + classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id) + else: + # 获取所有类型 + classes = service.list_classes_by_scene(scene_uuid, workspace_id) + + # 构建响应 + items = [] + for ontology_class in classes: + items.append(ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + )) + + response = ClassListResponse( + total=len(items), + scene_id=scene_uuid, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + items=items + ) + + if class_name: + api_logger.info( + f"Class search completed: found {len(items)} classes matching '{class_name}' " + f"in scene {scene_id}" + ) + else: + api_logger.info(f"Class list retrieved successfully, count={len(items)}") + + return success(data=response.model_dump(mode='json'), msg="查询成功") + + except ValueError as e: + api_logger.warning(f"Validation error in class {operation}: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index dba52d0b..61195deb 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -1,5 +1,5 @@ -import uuid import json +import uuid from fastapi import APIRouter, Depends, Path from sqlalchemy.orm import Session @@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse from app.core.logging_config import get_api_logger from app.core.response_utils import success from app.dependencies import get_current_user, get_db -from app.models.prompt_optimizer_model import RoleType -from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \ - OptimizePromptResponse, SessionHistoryResponse, SessionMessage +from app.schemas.prompt_optimizer_schema import ( + PromptOptMessage, + CreateSessionResponse, + SessionHistoryResponse, + SessionMessage, + PromptSaveRequest +) from app.schemas.response_schema import ApiResponse from app.services.prompt_optimizer_service import PromptOptimizerService @@ -135,3 +139,109 @@ async def get_prompt_opt( "X-Accel-Buffering": "no" } ) + + +@router.post( + "/releases", + summary="Get prompt optimization", + response_model=ApiResponse +) +def save_prompt( + data: PromptSaveRequest, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """ + Save a prompt release for the current tenant. + + Args: + data (PromptSaveRequest): Request body containing session_id, title, and prompt. + db (Session): SQLAlchemy database session, injected via dependency. + current_user: Currently authenticated user object, injected via dependency. + + Returns: + ApiResponse: Standard API response containing the saved prompt release info: + - id: UUID of the prompt release + - session_id: associated session + - title: prompt title + - prompt: prompt content + - created_at: timestamp of creation + + Raises: + Any database or service exceptions are propagated to the global exception handler. + """ + service = PromptOptimizerService(db) + prompt_info = service.save_prompt( + tenant_id=current_user.tenant_id, + session_id=data.session_id, + title=data.title, + prompt=data.prompt + ) + return success(data=prompt_info) + + +@router.delete( + "/releases/{prompt_id}", + summary="Delete prompt (soft delete)", + response_model=ApiResponse +) +def delete_prompt( + prompt_id: uuid.UUID = Path(..., description="Prompt ID"), + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """ + Soft delete a prompt release. + + Args: + prompt_id + db (Session): Database session + current_user: Current logged-in user + + Returns: + ApiResponse: Success message confirming deletion + """ + service = PromptOptimizerService(db) + service.delete_prompt( + tenant_id=current_user.tenant_id, + prompt_id=prompt_id + ) + return success(msg="Prompt deleted successfully") + + +@router.get( + "/releases/list", + summary="Get paginated list of released prompts with optional filter", + response_model=ApiResponse +) +def get_release_list( + page: int = 1, + page_size: int = 20, + keyword: str | None = None, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """ + Retrieve paginated list of released prompts for the current tenant. + Optionally filter by keyword in title. + + Args: + page (int): Page number (starting from 1) + page_size (int): Number of items per page (max 100) + keyword (str | None): Optional keyword to filter prompt titles + db (Session): Database session + current_user: Current logged-in user + + Returns: + ApiResponse: Contains paginated list of prompt releases with metadata + """ + service = PromptOptimizerService(db) + result = service.get_release_list( + tenant_id=current_user.tenant_id, + page=max(1, page), + page_size=min(max(1, page_size), 100), + filter_keyword=keyword + ) + return success(data=result) + + diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index a34c781f..e519ea53 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,29 +7,21 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os + import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence - +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger -from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType -from app.repositories.memory_short_repository import LongTermMemoryRepository from app.services.memory_agent_service import ( get_end_user_connected_config, ) -from app.services.memory_konwledges_server import write_rag -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool - -from app.utils.config_utils import resolve_config_id - logger = get_business_logger() @@ -106,7 +98,7 @@ class LangChainAgent: "streaming": streaming, "tool_count": len(self.tools), "tool_names": [tool.name for tool in self.tools] if self.tools else [], - "tool_count": len(self.tools) + # "tool_count": len(self.tools) } ) @@ -145,106 +137,8 @@ class LangChainAgent: user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}" messages.append(HumanMessage(content=user_content)) - return messages -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_save(self,messages,end_user_end,aimessages): - # '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' - # end_user_end=f"Term_{end_user_end}" - # print(messages) - # print(aimessages) - # session_id = store.save_session( - # userid=end_user_end, - # messages=messages, - # apply_id=end_user_end, - # end_user_id=end_user_end, - # aimessages=aimessages - # ) - # store.delete_duplicate_sessions() - # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') - # return session_id -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_redis_read(self,end_user_end): - # end_user_end = f"Term_{end_user_end}" - # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) - # # logger.info(f'Redis_Agent:{end_user_end};{history}') - # messagss_list=[] - # retrieved_content=[] - # for messages in history: - # query = messages.get("Query") - # aimessages = messages.get("Answer") - # messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - # retrieved_content.append({query: aimessages}) - # return messagss_list,retrieved_content - async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): - """ - 写入记忆(支持结构化消息) - - Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 - """ - - db = next(get_db()) - try: - actual_config_id=resolve_config_id(actual_config_id, db) - - if storage_type == "rag": - # RAG 模式:组合消息为字符串格式(保持原有逻辑) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - # Neo4j 模式:使用结构化消息列表 - structured_messages = [] - - # 始终添加用户消息(如果不为空) - if user_message: - structured_messages.append({"role": "user", "content": user_message}) - - # 只有当 AI 回复不为空时才添加 assistant 消息 - if ai_message: - structured_messages.append({"role": "assistant", "content": ai_message}) - - # 如果没有消息,直接返回 - if not structured_messages: - logger.warning(f"No messages to write for user {actual_end_user_id}") - return - - # 调用 Celery 任务,传递结构化消息列表 - # 数据流: - # 1. structured_messages 传递给 write_message_task - # 2. write_message_task 调用 memory_agent_service.write_memory - # 3. write_memory 调用 write_tools.write,传递 messages 参数 - # 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数 - # 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段 - # 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段 - logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - actual_config_id, # config_id: 配置ID - storage_type, # storage_type: "neo4j" - user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) - ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() async def chat( self, message: str, @@ -288,30 +182,6 @@ class LangChainAgent: actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') -# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码 -# history_term_memory_result = await self.term_memory_redis_read(end_user_id) -# history_term_memory = history_term_memory_result[0] -# db_for_memory = next(get_db()) -# if memory_flag: -# if len(history_term_memory)>=4 and storage_type != "rag": -# history_term_memory = ';'.join(history_term_memory) -# retrieved_content = history_term_memory_result[1] -# print(retrieved_content) -# # 为长期记忆操作获取新的数据库连接 -# try: -# repo = LongTermMemoryRepository(db_for_memory) -# repo.upsert(end_user_id, retrieved_content) -# logger.info( -# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') -# except Exception as e: -# logger.error(f"Failed to write to LongTermMemory: {e}") -# raise -# finally: -# db_for_memory.close() - -# # 长期记忆写入( -# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id) -# # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -332,17 +202,17 @@ class LangChainAgent: # 获取最后的 AI 消息 output_messages = result.get("messages", []) content = "" + total_tokens = 0 for msg in reversed(output_messages): if isinstance(msg, AIMessage): content = msg.content + response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break elapsed_time = time.time() - start_time if memory_flag: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) - await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # await self.term_memory_save(message_chat, end_user_id, content) + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) response = { "content": content, "model": self.model_name, @@ -350,7 +220,7 @@ class LangChainAgent: "usage": { "prompt_tokens": 0, "completion_tokens": 0, - "total_tokens": 0 + "total_tokens": total_tokens } } @@ -410,25 +280,7 @@ class LangChainAgent: db.close() except Exception as e: logger.warning(f"Failed to get db session: {e}") -# # TODO 乐力齐 -# history_term_memory_result = await self.term_memory_redis_read(end_user_id) -# history_term_memory = history_term_memory_result[0] -# if memory_flag: -# if len(history_term_memory) >= 4 and storage_type != "rag": -# history_term_memory = ';'.join(history_term_memory) -# retrieved_content = history_term_memory_result[1] -# db_for_memory = next(get_db()) -# try: -# repo = LongTermMemoryRepository(db_for_memory) -# repo.upsert(end_user_id, retrieved_content) -# logger.info( -# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') -# # 长期记忆写入 -# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id) -# except Exception as e: -# logger.error(f"Failed to write to long term memory: {e}") -# finally: -# db_for_memory.close() + # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: @@ -444,7 +296,7 @@ class LangChainAgent: # 统一使用 agent 的 astream_events 实现流式输出 logger.debug("使用 Agent astream_events 实现流式输出") - full_content='' + full_content = '' try: async for event in self.agent.astream_events( {"messages": messages}, @@ -481,12 +333,17 @@ class LangChainAgent: logger.debug(f"工具调用结束: {event.get('name')}") logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") + # 统计token消耗 + output_messages = event.get("data", {}).get("output", {}).get("messages", []) + for msg in reversed(output_messages): + if isinstance(msg, AIMessage): + response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", + 0) if response_meta else 0 + yield total_tokens + break if memory_flag: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) - await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # await self.term_memory_save(message_chat, end_user_id, full_content) - + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/config.py b/api/app/core/config.py index a8981054..0de957c7 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -157,6 +157,11 @@ class Settings: if origin.strip() ] + # Language Configuration + # Supported values: "zh" (Chinese), "en" (English) + # This controls the language used for memory summary titles and other generated content + DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh") + # Logging settings LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO") LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py new file mode 100644 index 00000000..895f61ac --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -0,0 +1,238 @@ +import json +import os + +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage + +from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel +from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ +from app.core.memory.agent.utils.redis_tool import write_store +from app.core.memory.agent.utils.redis_tool import count_store +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context, get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_konwledges_server import write_rag +from app.services.task_service import get_task_memory_write_result +from app.tasks import write_message_task +from app.utils.config_utils import resolve_config_id +logger = get_agent_logger(__name__) +template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') + +async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) + logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') +async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, + actual_config_id, long_term_messages=[]): + """ + 写入记忆(支持结构化消息) + + Args: + storage_type: 存储类型 (neo4j/rag) + end_user_id: 终端用户ID + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ + + db = next(get_db()) + try: + actual_config_id = resolve_config_id(actual_config_id, db) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if isinstance(user_message, str) and user_message.strip() != "": + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if isinstance(ai_message, str) and ai_message.strip() != "": + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果提供了 long_term_messages,使用它替代 structured_messages + if long_term_messages and isinstance(long_term_messages, list): + structured_messages = long_term_messages + elif long_term_messages and isinstance(long_term_messages, str): + # 如果是 JSON 字符串,先解析 + try: + structured_messages = json.loads(long_term_messages) + except json.JSONDecodeError: + logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + logger.info( + f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # end_user_id: 用户ID + structured_messages, # message: JSON 字符串格式的消息列表 + str(actual_config_id), # config_id: 配置ID字符串 + storage_type, # storage_type: "neo4j" + user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + write_status = get_task_memory_write_result(str(write_id)) + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + finally: + db.close() + +async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + with get_db_context() as db_session: + repo = LongTermMemoryRepository(db_session) + + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + + + +'''根据窗口''' +async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): + ''' + 根据窗口获取redis数据,写入neo4j: + Args: + end_user_id: 终端用户ID + memory_config: 内存配置对象 + langchain_messages:原始数据LIST + scope:窗口大小 + ''' + scope=scope + is_end_user_id = count_store.get_sessions_count(end_user_id) + if is_end_user_id is not False: + is_end_user_id = count_store.get_sessions_count(end_user_id)[0] + redis_messages = count_store.get_sessions_count(end_user_id)[1] + if is_end_user_id and int(is_end_user_id) != int(scope): + is_end_user_id += 1 + langchain_messages += redis_messages + count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) + elif int(is_end_user_id) == int(scope): + logger.info('写入长期记忆NEO4J') + formatted_messages = (redis_messages) + # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + if hasattr(memory_config, 'config_id'): + config_id = memory_config.config_id + else: + config_id = memory_config + + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + config_id, formatted_messages) + count_store.update_sessions_count(end_user_id, 1, langchain_messages) + else: + count_store.save_sessions_count(end_user_id, 1, langchain_messages) + + +"""根据时间""" +async def memory_long_term_storage(end_user_id,memory_config,time): + ''' + 根据时间获取redis数据,写入neo4j: + Args: + end_user_id: 终端用户ID + memory_config: 内存配置对象 + ''' + long_time_data = write_store.find_user_recent_sessions(end_user_id, time) + format_messages = (long_time_data) + messages=[] + memory_config=memory_config.config_id + for i in format_messages: + message=json.loads(i['Query']) + messages+= message + if format_messages!=[]: + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + memory_config, messages) +'''聚合判断''' +async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: + """ + 聚合判断函数:判断输入句子和历史消息是否描述同一事件 + + Args: + end_user_id: 终端用户ID + ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + memory_config: 内存配置对象 + """ + + try: + # 1. 获取历史会话数据(使用新方法) + result = write_store.get_all_sessions_by_end_user_id(end_user_id) + history = await format_parsing(result) + if not result: + history = [] + else: + history = await format_parsing(result) + json_schema = WriteAggregateModel.model_json_schema() + template_service = TemplateService(template_root) + system_prompt = await template_service.render_template( + template_name='write_aggregate_judgment.jinja2', + operation_name='aggregate_judgment', + history=history, + sentence=ori_messages, + json_schema=json_schema + ) + with get_db_context() as db_session: + factory = MemoryClientFactory(db_session) + llm_client = factory.get_llm_client(memory_config.llm_model_id) + messages = [ + { + "role": "user", + "content": system_prompt + } + ] + structured = await llm_client.response_structured( + messages=messages, + response_model=WriteAggregateModel + ) + output_value = structured.output + if isinstance(output_value, list): + output_value = [ + {"role": msg.role, "content": msg.content} + for msg in output_value + ] + + result_dict = { + "is_same_event": structured.is_same_event, + "output": output_value + } + if not structured.is_same_event: + logger.info(result_dict) + await write("neo4j", end_user_id, "", "", None, end_user_id, + memory_config.config_id, output_value) + return result_dict + + except Exception as e: + print(f"[aggregate_judgment] 发生错误: {e}") + import traceback + traceback.print_exc() + + return { + "is_same_event": False, + "output": ori_messages, + "messages": ori_messages, + "history": history if 'history' in locals() else [], + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index c4814de1..fcbb18e3 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): 清理后的数据 """ # 需要过滤的字段列表 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 fields_to_remove = { 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', - 'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" + 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" } if isinstance(data, dict): diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py new file mode 100644 index 00000000..9ce581ee --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -0,0 +1,72 @@ +import json + +from langchain_core.messages import HumanMessage, AIMessage +async def format_parsing(messages: list,type:str='string'): + """ + 格式化解析消息列表 + + Args: + messages: 消息列表 + type: 返回类型 ('string' 或 'dict') + + Returns: + 格式化后的消息列表 + """ + result = [] + user=[] + ai=[] + + for message in messages: + hstory_messages = message['messages'] + for history_messag in hstory_messages.strip().splitlines(): + history_messag = json.loads(history_messag) + for content in history_messag: + role = content['role'] + content = content['content'] + if type == "string": + if role == 'human' or role=="user": + content = '用户:' + content + else: + content = 'AI:' + content + result.append(content) + if type == "dict" : + if role == 'human' or role=="user": + user.append( content) + else: + ai.append(content) + if type == "dict": + for key,values in zip(user,ai): + result.append({key:values}) + return result + +async def messages_parse(messages: list | dict): + user=[] + ai=[] + database=[] + for message in messages: + Query = message['Query'] + Query = json.loads(Query) + for data in Query: + role = data['role'] + if role == "human": + user.append(data['content']) + if role == "ai": + ai.append(data['content']) + for key, values in zip(user, ai): + database.append({key, values}) + return database + + +async def agent_chat_messages(user_content,ai_content): + messages = [ + { + "role": "user", + "content": f"{user_content}" + }, + { + "role": "assistant", + "content": f"{ai_content}" + } + + ] + return messages diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 8b5de444..fd2c498c 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,22 +1,20 @@ import asyncio +import json import sys import warnings from contextlib import asynccontextmanager - - -from langchain_core.messages import HumanMessage from langgraph.constants import END, START from langgraph.graph import StateGraph - -from app.db import get_db +from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write +from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService + warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -34,14 +32,6 @@ async def make_write_graph(): end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration """ - # workflow = StateGraph(WriteState) - # workflow.add_node("content_input", content_input_write) - # workflow.add_node("save_neo4j", write_node) - # workflow.add_edge(START, "content_input") - # workflow.add_edge("content_input", "save_neo4j") - # workflow.add_edge("save_neo4j", END) - # - # graph = workflow.compile() workflow = StateGraph(WriteState) workflow.add_node("save_neo4j", write_node) workflow.add_edge(START, "save_neo4j") @@ -51,43 +41,63 @@ async def make_write_graph(): yield graph - -async def main(): - """主函数 - 运行工作流""" - message = "今天周一" - end_user_id = 'new_2025test1103' # 组ID - - +async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment + from app.core.memory.agent.utils.redis_tool import write_store + write_store.save_session_write(end_user_id, (langchain_messages)) # 获取数据库会话 - db_session = next(get_db()) - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=17, # 改为整数 - service_name="MemoryAgentService" - ) - try: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config} - - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j'==node_name: - massages=node_data - massages=massages.get('write_result')['status'] - print(massages) # | 更新数据: {node_data} - - except Exception as e: - import traceback - traceback.print_exc() + with get_db_context() as db_session: + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=memory_config, # 改为整数 + service_name="MemoryAgentService" + ) + if long_term_type=='chunk': + '''方案一:对话窗口6轮对话''' + await window_dialogue(end_user_id,langchain_messages,memory_config,scope) + if long_term_type=='time': + """时间""" + await memory_long_term_storage(end_user_id, memory_config,5) + if long_term_type=='aggregate': + """方案三:聚合判断""" + await aggregate_judgment(end_user_id, langchain_messages, memory_config) -if __name__ == "__main__": - import asyncio - asyncio.run(main()) \ No newline at end of file + +async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): + from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent + from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save + from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages + if storage_type == AgentMemory_Long_Term.STORAGE_RAG: + await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + else: + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK + SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE + long_term_messages = await agent_chat_messages(message_chat, aimessages) + await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) + await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) + +# async def main(): +# """主函数 - 运行工作流""" +# langchain_messages = [ +# { +# "role": "user", +# "content": "今天周五去爬山" +# }, +# { +# "role": "assistant", +# "content": "好耶" +# } +# +# ] +# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID +# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" +# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# +# +# +# if __name__ == "__main__": +# import asyncio +# asyncio.run(main()) \ No newline at end of file diff --git a/api/app/core/memory/agent/models/write_aggregate_model.py b/api/app/core/memory/agent/models/write_aggregate_model.py new file mode 100644 index 00000000..fd423314 --- /dev/null +++ b/api/app/core/memory/agent/models/write_aggregate_model.py @@ -0,0 +1,28 @@ +"""Pydantic models for write aggregate judgment operations.""" + +from typing import List, Union +from pydantic import BaseModel, Field + + +class MessageItem(BaseModel): + """Individual message item in conversation.""" + + role: str = Field(..., description="角色:user 或 assistant") + content: str = Field(..., description="消息内容") + + +class WriteAggregateResponse(BaseModel): + """Response model for aggregate judgment containing judgment result and output.""" + + is_same_event: bool = Field( + ..., + description="是否是同一事件。True表示是同一事件,False表示不同事件" + ) + output: Union[List[MessageItem], bool] = Field( + ..., + description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表" + ) + + +# 为了保持向后兼容,保留旧的类名作为别名 +WriteAggregateModel = WriteAggregateResponse diff --git a/api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 b/api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 new file mode 100644 index 00000000..fb0247aa --- /dev/null +++ b/api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 @@ -0,0 +1,57 @@ +输入句子:{{sentence}} +历史消息:{{history}} + +# 你的角色 +你是一个擅长事件聚合与语义判断的专家。 + +# 你的任务 +结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。 + +以下情况视为"同一事件"(需要返回 is_same_event=True, output=False): +- 描述的是同一个具体事件或事实 +- 存在明显的因果关系、前后发展关系 +- 是对同一事件的补充、解释、追问或延展 +- 逻辑上属于同一语境下的连续讨论 + +以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表): +- 话题不同,事件主体不同 +- 时间、地点、对象明显不同 +- 只是语义相似,但并非同一具体事件 +- 无直接事件、因果或逻辑关联 + +# 输出规则(非常重要) +你必须按照以下JSON格式输出: + +**如果是同一事件:** +```json +{ + "is_same_event": true, + "output": false +} +``` + +**如果不是同一事件:** +```json +{ + "is_same_event": false, + "output": [ + { + "role": "user", + "content": "输入句子的内容" + }, + { + "role": "assistant", + "content": "对应的回复内容" + } + ] +} +``` + +# JSON Schema +{{json_schema}} + +# 注意事项 +- 必须严格按照上述格式输出 +- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表 +- 消息列表必须包含 role 和 content 字段 +- 不要输出任何解释、分析或多余内容 diff --git a/api/app/core/memory/agent/utils/redis_base.py b/api/app/core/memory/agent/utils/redis_base.py new file mode 100644 index 00000000..59bac109 --- /dev/null +++ b/api/app/core/memory/agent/utils/redis_base.py @@ -0,0 +1,186 @@ +import json +from typing import Any, List, Dict, Optional +from datetime import datetime, timedelta + + +def serialize_messages(messages: Any) -> str: + """ + 将消息序列化为 JSON 字符串,支持 LangChain 消息对象 + + Args: + messages: 可以是 list、dict、string 或 LangChain 消息对象列表 + + Returns: + str: JSON 字符串 + """ + if isinstance(messages, str): + return messages + + if isinstance(messages, (list, tuple)): + # 检查是否是 LangChain 消息对象列表 + serialized_list = [] + for msg in messages: + if hasattr(msg, 'type') and hasattr(msg, 'content'): + # LangChain 消息对象 + serialized_list.append({ + 'type': msg.type, + 'content': msg.content, + 'role': getattr(msg, 'role', msg.type) + }) + elif isinstance(msg, dict): + serialized_list.append(msg) + else: + serialized_list.append(str(msg)) + return json.dumps(serialized_list, ensure_ascii=False) + + if isinstance(messages, dict): + return json.dumps(messages, ensure_ascii=False) + + # 其他类型转为字符串 + return str(messages) + + +def deserialize_messages(messages_str: str) -> Any: + """ + 将 JSON 字符串反序列化为原始格式 + + Args: + messages_str: JSON 字符串 + + Returns: + 反序列化后的对象(list、dict 或 string) + """ + if not messages_str: + return [] + + try: + return json.loads(messages_str) + except (json.JSONDecodeError, TypeError): + return messages_str + + +def fix_encoding(text: str) -> str: + """ + 修复错误编码的文本 + + Args: + text: 需要修复的文本 + + Returns: + str: 修复后的文本 + """ + if not text or not isinstance(text, str): + return text + try: + # 尝试修复 Latin-1 误编码为 UTF-8 的情况 + return text.encode('latin-1').decode('utf-8') + except (UnicodeDecodeError, UnicodeEncodeError): + # 如果修复失败,返回原文本 + return text + + +def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]: + """ + 格式化会话数据为统一的输出格式 + + Args: + data: 原始会话数据 + include_time: 是否包含时间字段 + + Returns: + Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."} + """ + result = { + "Query": fix_encoding(data.get('messages', '')), + "Answer": fix_encoding(data.get('aimessages', '')) + } + + if include_time: + result["starttime"] = data.get('starttime', '') + + return result + + +def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]: + """ + 根据时间范围过滤数据 + + Args: + items: 包含 starttime 字段的数据列表 + minutes: 时间范围(分钟) + + Returns: + List[Dict]: 过滤后的数据列表 + """ + time_threshold = datetime.now() - timedelta(minutes=minutes) + time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S") + + filtered_items = [] + for item in items: + starttime = item.get('starttime', '') + if starttime and starttime >= time_threshold_str: + filtered_items.append(item) + + return filtered_items + + +def sort_and_limit_results(items: List[Dict], limit: int = 6, + remove_time: bool = True) -> List[Dict]: + """ + 对结果进行排序、限制数量并移除时间字段 + + Args: + items: 数据列表 + limit: 最大返回数量 + remove_time: 是否移除 starttime 字段 + + Returns: + List[Dict]: 处理后的数据列表 + """ + # 按时间降序排序(最新的在前) + items.sort(key=lambda x: x.get('starttime', ''), reverse=True) + + # 限制数量 + result_items = items[:limit] + + # 移除 starttime 字段 + if remove_time: + for item in result_items: + item.pop('starttime', None) + + # 如果结果少于1条,返回空列表 + if len(result_items) < 1: + return [] + + return result_items + + +def generate_session_key(session_id: str, key_type: str = "session") -> str: + """ + 生成 Redis key + + Args: + session_id: 会话ID + key_type: key 类型 ("session", "read", "write", "count") + + Returns: + str: Redis key + """ + if key_type == "count": + return f"session:count:{session_id}" + elif key_type == "write": + return f"session:write:{session_id}" + elif key_type == "session" or key_type == "read": + return f"session:{session_id}" + else: + return f"session:{session_id}" + + +def get_current_timestamp() -> str: + """ + 获取当前时间戳字符串 + + Returns: + str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS" + """ + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") \ No newline at end of file diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index 505545b3..c5729628 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -1,11 +1,36 @@ import redis import uuid -from datetime import datetime from app.core.config import settings +from typing import List, Dict, Any, Optional, Union + +from app.core.memory.agent.utils.redis_base import ( + serialize_messages, + deserialize_messages, + fix_encoding, + format_session_data, + filter_by_time_range, + sort_and_limit_results, + generate_session_key, + get_current_timestamp +) -class RedisSessionStore: + + +class RedisWriteStore: + """Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): + """ + 初始化 Redis 连接 + + Args: + host: Redis 主机地址 + port: Redis 端口 + db: Redis 数据库编号 + password: Redis 密码 + session_id: 会话ID + """ self.r = redis.Redis( host=host, port=port, @@ -16,32 +41,437 @@ class RedisSessionStore: ) self.uudi = session_id - def _fix_encoding(self, text): - """修复错误编码的文本""" - if not text or not isinstance(text, str): - return text - try: - # 尝试修复 Latin-1 误编码为 UTF-8 的情况 - return text.encode('latin-1').decode('utf-8') - except (UnicodeDecodeError, UnicodeEncodeError): - # 如果修复失败,返回原文本 - return text - - # 修改后的 save_session 方法 - def save_session(self, userid, messages, aimessages, apply_id, end_user_id): + def save_session_write(self, userid: str, messages: str) -> str: """ 写入一条会话数据,返回 session_id - 优化版本:确保写入时间不超过1秒 + + Args: + userid: 用户ID + messages: 用户消息 + + Returns: + str: 新生成的 session_id """ try: - session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID - starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - key = f"session:{session_id}" # 使用新生成的 session_id 作为 key + messages = serialize_messages(messages) + session_id = str(uuid.uuid4()) + key = generate_session_key(session_id, key_type="write") - # 使用 pipeline 批量写入,减少网络往返 pipe = self.r.pipeline() + pipe.hset(key, mapping={ + "id": self.uudi, + "sessionid": userid, + "messages": messages, + "starttime": get_current_timestamp() + }) + result = pipe.execute() - # 直接写入数据,decode_responses=True 已经处理了编码 + print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") + return session_id + except Exception as e: + print(f"[save_session_write] 保存会话失败: {e}") + raise e + + def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: + """ + 通过 save_session_write 的 userid 获取 sessionid 和 messages + + Args: + userid: 用户ID (对应 sessionid 字段) + + Returns: + List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False + """ + try: + # 只查询 write 类型的 key + keys = self.r.keys('session:write:*') + if not keys: + return False + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合 userid 的数据 + results = [] + for key, data in zip(keys, all_data): + if not data: + continue + + # 从 write 类型读取,匹配 sessionid 字段 + if data.get('sessionid') == userid: + # 从 key 中提取 session_id: session:write:{session_id} + session_id = key.split(':')[-1] + results.append({ + "sessionid": session_id, + "messages": fix_encoding(data.get('messages', '')) + }) + + if not results: + return False + + print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") + return results + except Exception as e: + print(f"[get_session_by_userid] 查询失败: {e}") + return False + + def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: + """ + 通过 end_user_id 获取所有 write 类型的会话数据 + + Args: + end_user_id: 终端用户ID (对应 sessionid 字段) + + Returns: + List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False + + 返回格式: + [ + { + "session_id": "uuid", + "id": "...", + "sessionid": "end_user_id", + "messages": "...", + "starttime": "timestamp" + }, + ... + ] + """ + try: + # 只查询 write 类型的 key + keys = self.r.keys('session:write:*') + if not keys: + print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") + return False + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合 end_user_id 的数据 + results = [] + for key, data in zip(keys, all_data): + if not data: + continue + + # 从 write 类型读取,匹配 sessionid 字段 + if data.get('sessionid') == end_user_id: + # 从 key 中提取 session_id: session:write:{session_id} + session_id = key.split(':')[-1] + + # 构建完整的会话信息 + session_info = { + "session_id": session_id, + "id": data.get('id', ''), + "sessionid": data.get('sessionid', ''), + "messages": fix_encoding(data.get('messages', '')), + "starttime": data.get('starttime', '') + } + results.append(session_info) + + if not results: + print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") + return False + + # 按时间排序(最新的在前) + results.sort(key=lambda x: x.get('starttime', ''), reverse=True) + + print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") + return results + except Exception as e: + print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") + import traceback + traceback.print_exc() + return False + + def find_user_recent_sessions(self, userid: str, + minutes: int = 5) -> List[Dict[str, str]]: + """ + 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 + + Args: + userid: 用户ID (对应 sessionid 字段) + minutes: 查询最近几分钟的数据,默认5分钟 + + Returns: + List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...] + """ + import time + start_time = time.time() + + # 只查询 write 类型的 key + keys = self.r.keys('session:write:*') + if not keys: + print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + return [] + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合 userid 的数据 + matched_items = [] + for data in all_data: + if not data: + continue + + # 从 write 类型读取,匹配 sessionid 字段 + if data.get('sessionid') == userid and data.get('starttime'): + # write 类型没有 aimessages,所以 Answer 为空 + matched_items.append({ + "Query": fix_encoding(data.get('messages', '')), + "Answer": "", + "starttime": data.get('starttime', '') + }) + + # 根据时间范围过滤 + filtered_items = filter_by_time_range(matched_items, minutes) + # 排序并移除时间字段 + result_items = sort_and_limit_results(filtered_items, limit=None) + print(result_items) + + elapsed_time = time.time() - start_time + print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " + f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + + return result_items + + def delete_all_write_sessions(self) -> int: + """ + 删除所有 write 类型的会话 + + Returns: + int: 删除的数量 + """ + keys = self.r.keys('session:write:*') + if keys: + return self.r.delete(*keys) + return 0 + + +class RedisCountStore: + """Redis Count 类型存储类,用于管理访问次数统计相关的数据""" + + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): + """ + 初始化 Redis 连接 + + Args: + host: Redis 主机地址 + port: Redis 端口 + db: Redis 数据库编号 + password: Redis 密码 + session_id: 会话ID + """ + self.r = redis.Redis( + host=host, + port=port, + db=db, + password=password, + decode_responses=True, + encoding='utf-8' + ) + self.uudi = session_id + + def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: + """ + 保存用户访问次数统计 + + Args: + end_user_id: 终端用户ID + count: 访问次数 + messages: 消息内容 + + Returns: + str: 新生成的 session_id + """ + session_id = str(uuid.uuid4()) + key = generate_session_key(session_id, key_type="count") + index_key = f'session:count:index:{end_user_id}' # 索引键 + + pipe = self.r.pipeline() + pipe.hset(key, mapping={ + "id": self.uudi, + "end_user_id": end_user_id, + "count": int(count), + "messages": serialize_messages(messages), + "starttime": get_current_timestamp() + }) + pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 + + # 创建索引:end_user_id -> session_id 映射 + pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) + + result = pipe.execute() + + print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") + return session_id + + def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: + """ + 通过 end_user_id 查询访问次数统计 + + Args: + end_user_id: 终端用户ID + + Returns: + list 或 False: 如果找到返回 [count, messages],否则返回 False + """ + try: + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' + + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + self.r.delete(index_key) + return False + except Exception as type_error: + print(f"[get_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + return False + + # 直接获取数据 + key = generate_session_key(session_id, key_type="count") + data = self.r.hgetall(key) + + if not data: + # 索引存在但数据不存在,清理索引 + self.r.delete(index_key) + return False + + count = data.get('count') + messages_str = data.get('messages') + + if count is not None: + messages = deserialize_messages(messages_str) + return [int(count), messages] + + return False + except Exception as e: + print(f"[get_sessions_count] 查询失败: {e}") + return False + def update_sessions_count(self, end_user_id: str, new_count: int, + messages: Any) -> bool: + """ + 通过 end_user_id 修改访问次数统计(优化版:使用索引) + + Args: + end_user_id: 终端用户ID + new_count: 新的 count 值 + messages: 消息内容 + + Returns: + bool: 更新成功返回 True,未找到记录返回 False + """ + try: + # 使用索引键快速查找 + index_key = f'session:count:index:{end_user_id}' + + # 检查索引键类型,避免 WRONGTYPE 错误 + try: + key_type = self.r.type(index_key) + if key_type != 'string' and key_type != 'none': + # 索引键类型错误,删除并返回 False + print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + self.r.delete(index_key) + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + except Exception as type_error: + print(f"[update_sessions_count] 检查键类型失败: {type_error}") + + session_id = self.r.get(index_key) + + if not session_id: + print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + return False + + # 直接更新数据 + key = generate_session_key(session_id, key_type="count") + messages_str = serialize_messages(messages) + + pipe = self.r.pipeline() + pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'messages', messages_str) + result = pipe.execute() + + print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + return True + + except Exception as e: + print(f"[update_sessions_count] 更新失败: {e}") + return False + + def delete_all_count_sessions(self) -> int: + """ + 删除所有 count 类型的会话 + + Returns: + int: 删除的数量 + """ + keys = self.r.keys('session:count:*') + if keys: + return self.r.delete(*keys) + return 0 + + +class RedisSessionStore: + """Redis 会话存储类,用于管理会话数据""" + + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): + """ + 初始化 Redis 连接 + + Args: + host: Redis 主机地址 + port: Redis 端口 + db: Redis 数据库编号 + password: Redis 密码 + session_id: 会话ID + """ + self.r = redis.Redis( + host=host, + port=port, + db=db, + password=password, + decode_responses=True, + encoding='utf-8' + ) + self.uudi = session_id + + # ==================== 写入操作 ==================== + + def save_session(self, userid: str, messages: str, aimessages: str, + apply_id: str, end_user_id: str) -> str: + """ + 写入一条会话数据,返回 session_id + + Args: + userid: 用户ID + messages: 用户消息 + aimessages: AI回复消息 + apply_id: 应用ID + end_user_id: 终端用户ID + + Returns: + str: 新生成的 session_id + """ + try: + session_id = str(uuid.uuid4()) + key = generate_session_key(session_id, key_type="read") + + pipe = self.r.pipeline() pipe.hset(key, mapping={ "id": self.uudi, "sessionid": userid, @@ -49,177 +479,195 @@ class RedisSessionStore: "end_user_id": end_user_id, "messages": messages, "aimessages": aimessages, - "starttime": starttime + "starttime": get_current_timestamp() }) - - # 可选:设置过期时间(例如30天),避免数据无限增长 - # pipe.expire(key, 30 * 24 * 60 * 60) - - # 执行批量操作 result = pipe.execute() - print(f"保存结果: {result[0]}, session_id: {session_id}") - return session_id # 返回新生成的 session_id + print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") + return session_id except Exception as e: - print(f"保存会话失败: {e}") + print(f"[save_session] 保存会话失败: {e}") raise e - def save_sessions_batch(self, sessions_data): - """ - 批量写入多条会话数据,返回 session_id 列表 - sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id - 优化版本:批量操作,大幅提升性能 - """ - try: - session_ids = [] - pipe = self.r.pipeline() - - for session in sessions_data: - session_id = str(uuid.uuid4()) - starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - key = f"session:{session_id}" - - pipe.hset(key, mapping={ - "id": self.uudi, - "sessionid": session.get('userid'), - "apply_id": session.get('apply_id'), - "end_user_id": session.get('end_user_id'), - "messages": session.get('messages'), - "aimessages": session.get('aimessages'), - "starttime": starttime - }) - - session_ids.append(session_id) - - # 一次性执行所有写入操作 - results = pipe.execute() - print(f"批量保存完成: {len(session_ids)} 条记录") - return session_ids - except Exception as e: - print(f"批量保存会话失败: {e}") - raise e - - # ---------------- 读取 ---------------- - def get_session(self, session_id): + # ==================== 读取操作 ==================== + + def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """ 读取一条会话数据 + + Args: + session_id: 会话ID + + Returns: + Dict 或 None: 会话数据 """ - key = f"session:{session_id}" + key = generate_session_key(session_id) data = self.r.hgetall(key) return data if data else None - def get_session_apply_group(self, sessionid, apply_id, end_user_id): + def get_all_sessions(self) -> Dict[str, Dict[str, Any]]: """ - 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据 - """ - result_items = [] - - # 遍历所有会话数据 - for key in self.r.keys('session:*'): - data = self.r.hgetall(key) - - if not data: - continue - - # 检查三个条件是否都匹配 - if (data.get('sessionid') == sessionid and - data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): - result_items.append(data) - - return result_items - - def get_all_sessions(self): - """ - 获取所有会话数据 + 获取所有会话数据(不包括 count 和 write 类型) + + Returns: + Dict: 所有会话数据,key 为 session_id """ sessions = {} for key in self.r.keys('session:*'): - sid = key.split(':')[1] - sessions[sid] = self.get_session(sid) + # 排除 count 和 write 类型的 key + if ':count:' not in key and ':write:' not in key: + sid = key.split(':')[1] + sessions[sid] = self.get_session(sid) return sessions - # ---------------- 更新 ---------------- - def update_session(self, session_id, field, value): + def find_user_apply_group(self, sessionid: str, apply_id: str, + end_user_id: str) -> List[Dict[str, str]]: + """ + 根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条 + + Args: + sessionid: 会话ID(支持模糊匹配) + apply_id: 应用ID + end_user_id: 终端用户ID + + Returns: + List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...] + """ + import time + start_time = time.time() + + keys = self.r.keys('session:*') + if not keys: + print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + return [] + + # 批量获取数据 + pipe = self.r.pipeline() + for key in keys: + # 排除 count 和 write 类型 + if ':count:' not in key and ':write:' not in key: + pipe.hgetall(key) + all_data = pipe.execute() + + # 筛选符合条件的数据 + matched_items = [] + for data in all_data: + if not data: + continue + + if (data.get('apply_id') == apply_id and + data.get('end_user_id') == end_user_id): + # 支持模糊匹配或完全匹配 sessionid + if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: + matched_items.append(format_session_data(data, include_time=True)) + + # 排序、限制数量并移除时间字段 + result_items = sort_and_limit_results(matched_items, limit=6) + + elapsed_time = time.time() - start_time + print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + + return result_items + + # ==================== 更新操作 ==================== + + def update_session(self, session_id: str, field: str, value: Any) -> bool: """ 更新单个字段 - 优化版本:使用 pipeline 减少网络往返 + + Args: + session_id: 会话ID + field: 字段名 + value: 字段值 + + Returns: + bool: 是否更新成功 """ - key = f"session:{session_id}" + key = generate_session_key(session_id) pipe = self.r.pipeline() pipe.exists(key) pipe.hset(key, field, value) results = pipe.execute() - return bool(results[0]) # 返回 key 是否存在 + return bool(results[0]) - # ---------------- 删除 ---------------- - def delete_session(self, session_id): + # ==================== 删除操作 ==================== + + def delete_session(self, session_id: str) -> int: """ 删除单条会话 + + Args: + session_id: 会话ID + + Returns: + int: 删除的数量 """ - key = f"session:{session_id}" + key = generate_session_key(session_id) return self.r.delete(key) - def delete_all_sessions(self): + def delete_all_sessions(self) -> int: """ - 删除所有会话 + 删除所有会话(不包括 count 和 write 类型) + + Returns: + int: 删除的数量 """ keys = self.r.keys('session:*') - if keys: - return self.r.delete(*keys) + # 过滤掉 count 和 write 类型 + keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k] + if keys_to_delete: + return self.r.delete(*keys_to_delete) return 0 - def delete_duplicate_sessions(self): + def delete_duplicate_sessions(self) -> int: """ - 删除重复会话数据,条件: - "sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除 - 优化版本:使用 pipeline 批量操作,确保在1秒内完成 + 删除重复会话数据(不包括 count 和 write 类型) + 条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个 + + Returns: + int: 删除的数量 """ import time start_time = time.time() - # 第一步:使用 pipeline 批量获取所有 key keys = self.r.keys('session:*') - if not keys: print("[delete_duplicate_sessions] 没有会话数据") return 0 - # 第二步:使用 pipeline 批量获取所有数据 + # 批量获取所有数据 pipe = self.r.pipeline() for key in keys: - pipe.hgetall(key) + # 排除 count 和 write 类型 + if ':count:' not in key and ':write:' not in key: + pipe.hgetall(key) all_data = pipe.execute() - # 第三步:在内存中识别重复数据 - seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key) - keys_to_delete = [] # 需要删除的 key 列表 + # 识别重复数据 + seen = {} + keys_to_delete = [] - for key, data in zip(keys, all_data, strict=False): + for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False): if not data: continue - # 获取五个字段的值 - sessionid = data.get('sessionid', '') - user_id = data.get('id', '') - end_user_id = data.get('end_user_id', '') - messages = data.get('messages', '') - aimessages = data.get('aimessages', '') - # 用五元组作为唯一标识 - identifier = (sessionid, user_id, end_user_id, messages, aimessages) + identifier = ( + data.get('sessionid', ''), + data.get('id', ''), + data.get('end_user_id', ''), + data.get('messages', ''), + data.get('aimessages', '') + ) if identifier in seen: - # 重复,标记为待删除 keys_to_delete.append(key) else: - # 第一次出现,记录 seen[identifier] = key - # 第四步:使用 pipeline 批量删除重复的 key + # 批量删除重复的 key deleted_count = 0 if keys_to_delete: - # 分批删除,避免单次操作过大 batch_size = 1000 for i in range(0, len(keys_to_delete), batch_size): batch = keys_to_delete[i:i + batch_size] @@ -233,79 +681,28 @@ class RedisSessionStore: print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") return deleted_count - def find_user_session(self, sessionid): - user_id = sessionid - - result_items = [] - for key, values in store.get_all_sessions().items(): - history = {} - if user_id == str(values['sessionid']): - history["Query"] = values['messages'] - history["Answer"] = values['aimessages'] - result_items.append(history) - - if len(result_items) <= 1: - result_items = [] - return (result_items) - - def find_user_apply_group(self, sessionid, apply_id, end_user_id): - """ - 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条 - """ - import time - start_time = time.time() - # 使用 pipeline 批量获取数据,提高性能 - keys = self.r.keys('session:*') - - if not keys: - print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") - return [] - - # 使用 pipeline 批量获取所有 hash 数据 - pipe = self.r.pipeline() - for key in keys: - pipe.hgetall(key) - all_data = pipe.execute() - - # 解析并筛选符合条件的数据 - matched_items = [] - for data in all_data: - if not data: - continue - - # 检查是否符合三个条件 - - if (data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): - # 支持模糊匹配 sessionid 或者完全匹配 - if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: - matched_items.append({ - "Query": self._fix_encoding(data.get('messages')), - "Answer": self._fix_encoding(data.get('aimessages')), - "starttime": data.get('starttime', '') - }) - # 按时间降序排序(最新的在前) - matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True) - # 只保留最新的6条 - result_items = matched_items[:6] - # # 移除 starttime 字段 - for item in result_items: - item.pop('starttime', None) - - # 如果结果少于等于1条,返回空列表 - if len(result_items) <= 1: - result_items = [] - - elapsed_time = time.time() - start_time - print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") - - return result_items - +# 全局实例 store = RedisSessionStore( host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, session_id=str(uuid.uuid4()) -) \ No newline at end of file +) + +write_store = RedisWriteStore( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, + session_id=str(uuid.uuid4()) +) + +count_store = RedisCountStore( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, + session_id=str(uuid.uuid4()) +) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 446ab86a..aa66014c 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline This module provides the main write function for executing the knowledge extraction pipeline. Only MemoryConfig is needed - clients are constructed internally. """ +import asyncio import time from datetime import datetime @@ -123,23 +124,48 @@ async def write( except Exception as e: logger.error(f"Error creating indexes: {e}", exc_info=True) + # 添加死锁重试机制 + max_retries = 3 + retry_delay = 1 # 秒 + + for attempt in range(max_retries): + try: + success = await save_dialog_and_statements_to_neo4j( + dialogue_nodes=all_dialogue_nodes, + chunk_nodes=all_chunk_nodes, + statement_nodes=all_statement_nodes, + entity_nodes=all_entity_nodes, + statement_chunk_edges=all_statement_chunk_edges, + statement_entity_edges=all_statement_entity_edges, + entity_edges=all_entity_entity_edges, + connector=neo4j_connector + ) + if success: + logger.info("Successfully saved all data to Neo4j") + break + else: + logger.warning("Failed to save some data to Neo4j") + if attempt < max_retries - 1: + logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + except Exception as e: + error_msg = str(e) + # 检查是否是死锁错误 + if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower(): + if attempt < max_retries - 1: + logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + else: + logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}") + raise + else: + # 非死锁错误,直接抛出 + raise + try: - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=all_dialogue_nodes, - chunk_nodes=all_chunk_nodes, - statement_nodes=all_statement_nodes, - entity_nodes=all_entity_nodes, - statement_chunk_edges=all_statement_chunk_edges, - statement_entity_edges=all_statement_entity_edges, - entity_edges=all_entity_entity_edges, - connector=neo4j_connector - ) - if success: - logger.info("Successfully saved all data to Neo4j") - else: - logger.warning("Failed to save some data to Neo4j") - finally: await neo4j_connector.close() + except Exception as e: + logger.error(f"Error closing Neo4j connector: {e}") log_time("Neo4j Database Save", time.time() - step_start, log_file) diff --git a/api/app/core/memory/models/__init__.py b/api/app/core/memory/models/__init__.py index 1de3424a..8c573b7a 100644 --- a/api/app/core/memory/models/__init__.py +++ b/api/app/core/memory/models/__init__.py @@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import ( TripletExtractionResponse, ) +# Ontology models +from app.core.memory.models.ontology_models import ( + OntologyClass, + OntologyExtractionResponse, +) + # Variable configuration models from app.core.memory.models.variate_config import ( StatementExtractionConfig, @@ -105,6 +111,9 @@ __all__ = [ "Entity", "Triplet", "TripletExtractionResponse", + # Ontology models + "OntologyClass", + "OntologyExtractionResponse", # Variable configuration "StatementExtractionConfig", "ForgettingEngineConfig", diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 79b88fdc..1880b9ab 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -413,7 +413,8 @@ class ExtractedEntityNode(Node): description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") - fact_summary: str = Field(default="", description="Summary of the fact about this entity") + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") diff --git a/api/app/core/memory/models/ontology_models.py b/api/app/core/memory/models/ontology_models.py new file mode 100644 index 00000000..24a61f5f --- /dev/null +++ b/api/app/core/memory/models/ontology_models.py @@ -0,0 +1,135 @@ +"""Models for ontology classes and extraction responses. + +This module contains Pydantic models for representing extracted ontology classes +from scenario descriptions, following OWL ontology engineering standards. + +Classes: + OntologyClass: Represents an extracted ontology class + OntologyExtractionResponse: Response model containing extracted ontology classes +""" + +from typing import List, Optional +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class OntologyClass(BaseModel): + """Represents an extracted ontology class from scenario description. + + An ontology class represents an abstract category or concept in a domain, + following OWL ontology engineering standards and naming conventions. + + Attributes: + id: Unique string identifier for the ontology class + name: Name of the class in PascalCase format (e.g., 'MedicalProcedure') + name_chinese: Chinese translation of the class name (e.g., '医疗程序') + description: Textual description of the class + examples: List of concrete instance examples of this class + parent_class: Optional name of the parent class in the hierarchy + entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept') + domain: Domain this class belongs to (e.g., 'Healthcare', 'Education') + + Config: + extra: Ignore extra fields from LLM output + """ + model_config = ConfigDict(extra='ignore') + + id: str = Field( + default_factory=lambda: uuid4().hex, + description="Unique identifier for the ontology class" + ) + name: str = Field( + ..., + description="Name of the class in PascalCase format" + ) + name_chinese: Optional[str] = Field( + None, + description="Chinese translation of the class name" + ) + description: str = Field( + ..., + description="Description of the class" + ) + examples: List[str] = Field( + default_factory=list, + description="List of concrete instance examples" + ) + parent_class: Optional[str] = Field( + None, + description="Name of the parent class in the hierarchy" + ) + entity_type: str = Field( + ..., + description="Type/category of the entity" + ) + domain: str = Field( + ..., + description="Domain this class belongs to" + ) + + @field_validator('name') + @classmethod + def validate_pascal_case(cls, v: str) -> str: + """Validate that the class name follows PascalCase convention. + + PascalCase rules: + - Must start with an uppercase letter + - Cannot contain spaces + - Should not contain special characters except underscores + + Args: + v: The class name to validate + + Returns: + The validated class name + + Raises: + ValueError: If the name doesn't follow PascalCase convention + """ + if not v: + raise ValueError("Class name cannot be empty") + + if not v[0].isupper(): + raise ValueError( + f"Class name '{v}' must start with an uppercase letter (PascalCase)" + ) + + if ' ' in v: + raise ValueError( + f"Class name '{v}' cannot contain spaces (PascalCase)" + ) + + # Check for invalid characters (allow alphanumeric and underscore only) + if not all(c.isalnum() or c == '_' for c in v): + raise ValueError( + f"Class name '{v}' contains invalid characters. " + "Only alphanumeric characters and underscores are allowed" + ) + + return v + + +class OntologyExtractionResponse(BaseModel): + """Response model for ontology extraction from LLM. + + This model represents the structured output from the LLM when + extracting ontology classes from scenario descriptions. + + Attributes: + classes: List of extracted ontology classes + domain: Domain/field the scenario belongs to + + Config: + extra: Ignore extra fields from LLM output + """ + model_config = ConfigDict(extra='ignore') + + classes: List[OntologyClass] = Field( + default_factory=list, + description="List of extracted ontology classes" + ) + domain: str = Field( + ..., + description="Domain/field the scenario belongs to" + ) diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index a425e0ed..f2f14d9e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode): if len(desc_b) > len(desc_a): canonical.description = desc_b # 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序 - fact_a = getattr(canonical, "fact_summary", "") or "" - fact_b = getattr(ent, "fact_summary", "") or "" - def _extract_sources(txt: str) -> List[str]: - sources: List[str] = [] - if not txt: - return sources - for line in str(txt).splitlines(): - ln = line.strip() + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = getattr(canonical, "fact_summary", "") or "" + # fact_b = getattr(ent, "fact_summary", "") or "" + # def _extract_sources(txt: str) -> List[str]: + # sources: List[str] = [] + # if not txt: + # return sources + # for line in str(txt).splitlines(): + # ln = line.strip() # 支持“来源:”或“来源:”前缀 - m = re.match(r"^来源[::]\s*(.+)$", ln) - if m: - content = m.group(1).strip() - if content: - sources.append(content) + # m = re.match(r"^来源[::]\s*(.+)$", ln) + # if m: + # content = m.group(1).strip() + # if content: + # sources.append(content) # 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失 - if not sources and txt.strip(): - sources.append(txt.strip()) - return sources + # if not sources and txt.strip(): + # sources.append(txt.strip()) + # return sources try: - src_a = _extract_sources(fact_a) - src_b = _extract_sources(fact_b) - seen = set() - merged_sources: List[str] = [] - for s in src_a + src_b: - if s and s not in seen: - seen.add(s) - merged_sources.append(s) - if merged_sources: - name_line = f"实体: {getattr(canonical, 'name', '')}".strip() - canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) - elif fact_b and not fact_a: - canonical.fact_summary = fact_b + # src_a = _extract_sources(fact_a) + # src_b = _extract_sources(fact_b) + # seen = set() + # merged_sources: List[str] = [] + # for s in src_a + src_b: + # if s and s not in seen: + # seen.add(s) + # merged_sources.append(s) + # if merged_sources: + # name_line = f"实体: {getattr(canonical, 'name', '')}".strip() + # canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) + # elif fact_b and not fact_a: + # canonical.fact_summary = fact_b + pass except Exception: # 兜底:若解析失败,保留较长文本 - if len(fact_b) > len(fact_a): - canonical.fact_summary = fact_b + # if len(fact_b) > len(fact_a): + # canonical.fact_summary = fact_b + pass except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py index 0249ac1f..a028e916 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py @@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: # # 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整) desc_a = (getattr(a, "description", "") or "") desc_b = (getattr(b, "description", "") or "") - fact_a = (getattr(a, "fact_summary", "") or "") - fact_b = (getattr(b, "fact_summary", "") or "") - score_a = len(desc_a) + len(fact_a) - score_b = len(desc_b) + len(fact_b) + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_a = (getattr(a, "fact_summary", "") or "") + # fact_b = (getattr(b, "fact_summary", "") or "") + # score_a = len(desc_a) + len(fact_a) + # score_b = len(desc_b) + len(fact_b) + score_a = len(desc_a) + score_b = len(desc_b) if score_a != score_b: return 0 if score_a >= score_b else 1 return 0 @@ -189,7 +192,8 @@ async def _judge_pair( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -197,7 +201,8 @@ async def _judge_pair( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } # 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式) @@ -248,7 +253,8 @@ async def _judge_pair_disamb( "entity_type": getattr(a, "entity_type", None), "description": getattr(a, "description", None), "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(a, "fact_summary", None), "connect_strength": getattr(a, "connect_strength", None), } entity_b = { @@ -256,7 +262,8 @@ async def _judge_pair_disamb( "entity_type": getattr(b, "entity_type", None), "description": getattr(b, "description", None), "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # "fact_summary": getattr(b, "fact_summary", None), "connect_strength": getattr(b, "connect_strength", None), } prompt = render_entity_dedup_prompt( diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py index dbc697d9..028a926f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py @@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: description=row.get("description") or "", aliases=row.get("aliases") or [], name_embedding=row.get("name_embedding") or [], - fact_summary=row.get("fact_summary") or "", + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=row.get("fact_summary") or "", connect_strength=row.get("connect_strength") or "", ) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 7b7e854b..8a99cb40 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1085,7 +1085,8 @@ class ExtractionOrchestrator: entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type description=getattr(entity, 'description', ''), # 添加必需的 description 字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段 - fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py index 53815124..0bc09622 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py @@ -8,4 +8,5 @@ - TemporalExtractor: 时间信息提取 - EmbeddingGenerator: 嵌入向量生成 - MemorySummaryGenerator: 记忆摘要生成 +- OntologyExtractor: 本体类提取 """ diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index f39313a8..58633363 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -14,6 +14,34 @@ from pydantic import Field logger = get_memory_logger(__name__) +# 支持的语言列表和默认回退值 +SUPPORTED_LANGUAGES = {"zh", "en"} +FALLBACK_LANGUAGE = "en" + + +def validate_language(language: Optional[str]) -> str: + """ + 校验语言参数,确保其为有效值。 + + Args: + language: 待校验的语言代码 + + Returns: + 有效的语言代码("zh" 或 "en") + """ + if language is None: + return FALLBACK_LANGUAGE + + lang = str(language).lower().strip() + if lang in SUPPORTED_LANGUAGES: + return lang + + logger.warning( + f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'。" + f"支持的语言: {SUPPORTED_LANGUAGES}" + ) + return FALLBACK_LANGUAGE + class MemorySummaryResponse(RobustLLMResponse): """Structured response for summary generation per chunk. @@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse): async def generate_title_and_type_for_summary( content: str, - llm_client + llm_client, + language: str = None ) -> Tuple[str, str]: """ 为MemorySummary生成标题和类型 @@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary( Args: content: Summary的内容文本 llm_client: LLM客户端实例 + language: 生成标题使用的语言 ("zh" 中文, "en" 英文),如果为None则从配置读取 Returns: (标题, 类型)元组 """ from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt + from app.core.config import settings + + # 如果没有指定语言,从配置中读取,并校验有效性 + if language is None: + language = settings.DEFAULT_LANGUAGE + language = validate_language(language) # 定义有效的类型集合 VALID_TYPES = { @@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary( } DEFAULT_TYPE = "conversation" # 默认类型 + # 根据语言设置默认标题 + DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content" + PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed" + ERROR_TITLE = "错误" if language == "zh" else "Error" + UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title" + try: if not content: - logger.warning("content为空,无法生成标题和类型") - return ("空内容", DEFAULT_TYPE) + logger.warning(f"content为空,无法生成标题和类型 (language={language})") + return (DEFAULT_TITLE, DEFAULT_TYPE) - # 1. 渲染Jinja2提示词模板 - prompt = await render_episodic_title_and_type_prompt(content) + # 1. 渲染Jinja2提示词模板,传递语言参数 + prompt = await render_episodic_title_and_type_prompt(content, language=language) # 2. 调用LLM生成标题和类型 messages = [ @@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary( json_str = json_str.strip() result_data = json.loads(json_str) - title = result_data.get("title", "未知标题") + title = result_data.get("title", UNKNOWN_TITLE) episodic_type_raw = result_data.get("type", DEFAULT_TYPE) # 5. 校验和归一化类型 @@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary( f"已归一化为 '{episodic_type}'" ) - logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}") + logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}") return (title, episodic_type) except json.JSONDecodeError: - logger.error(f"无法解析LLM响应为JSON: {full_response}") - return ("解析失败", DEFAULT_TYPE) + logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}") + return (PARSE_ERROR_TITLE, DEFAULT_TYPE) except Exception as e: - logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True) - return ("错误", DEFAULT_TYPE) + logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True) + return (ERROR_TITLE, DEFAULT_TYPE) async def _process_chunk_summary( dialog: DialogData, @@ -153,11 +195,16 @@ async def _process_chunk_summary( return None try: + # 从配置中获取语言设置(只获取一次,复用),并校验有效性 + from app.core.config import settings + language = validate_language(settings.DEFAULT_LANGUAGE) + # Render prompt via Jinja2 for a single chunk prompt_content = await render_memory_summary_prompt( chunk_texts=chunk.content, json_schema=MemorySummaryResponse.model_json_schema(), max_words=200, + language=language, ) messages = [ @@ -178,9 +225,10 @@ async def _process_chunk_summary( try: title, episodic_type = await generate_title_and_type_for_summary( content=summary_text, - llm_client=llm_client + llm_client=llm_client, + language=language ) - logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}") + logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}") except Exception as e: logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}") # Continue without title and type diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/ontology_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/ontology_extraction.py new file mode 100644 index 00000000..d1b79bd1 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/ontology_extraction.py @@ -0,0 +1,482 @@ +"""Ontology class extraction from scenario descriptions using LLM. + +This module provides the OntologyExtractor class for extracting ontology classes +from natural language scenario descriptions. It uses LLM-driven extraction combined +with two-layer validation (string validation + OWL semantic validation). + +Classes: + OntologyExtractor: Extracts ontology classes from scenario descriptions +""" + +import asyncio +import logging +import time +from typing import List, Optional + +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.memory.models.ontology_models import ( + OntologyClass, + OntologyExtractionResponse, +) +from app.core.memory.utils.validation.ontology_validator import OntologyValidator +from app.core.memory.utils.validation.owl_validator import OWLValidator +from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt + + +logger = logging.getLogger(__name__) + + +class OntologyExtractor: + """Extractor for ontology classes from scenario descriptions. + + This extractor uses LLM to identify abstract classes and concepts from + natural language scenario descriptions, following OWL ontology engineering + standards. It performs two-layer validation: + 1. String validation (naming conventions, reserved words, duplicates) + 2. OWL semantic validation (consistency checking, circular inheritance) + + Attributes: + llm_client: OpenAI client for LLM calls + validator: String validator for class names and descriptions + owl_validator: OWL validator for semantic validation + """ + + def __init__(self, llm_client: OpenAIClient): + """Initialize the OntologyExtractor. + + Args: + llm_client: OpenAIClient instance for LLM processing + """ + self.llm_client = llm_client + self.validator = OntologyValidator() + self.owl_validator = OWLValidator() + + logger.info("OntologyExtractor initialized") + + async def extract_ontology_classes( + self, + scenario: str, + domain: Optional[str] = None, + max_classes: int = 15, + min_classes: int = 5, + enable_owl_validation: bool = True, + llm_temperature: float = 0.3, + llm_max_tokens: int = 2000, + max_description_length: int = 500, + timeout: Optional[float] = None, + ) -> OntologyExtractionResponse: + """Extract ontology classes from a scenario description. + + This is the main extraction method that orchestrates the entire process: + 1. Call LLM to extract ontology classes + 2. Perform first-layer validation (string validation and cleaning) + 3. Perform second-layer validation (OWL semantic validation) + 4. Filter invalid classes based on validation errors + 5. Return validated ontology classes + + Args: + scenario: Natural language scenario description + domain: Optional domain hint (e.g., "Healthcare", "Education") + max_classes: Maximum number of classes to extract (default: 15) + min_classes: Minimum number of classes to extract (default: 5) + enable_owl_validation: Whether to enable OWL validation (default: True) + llm_temperature: LLM temperature parameter (default: 0.3) + llm_max_tokens: LLM max tokens parameter (default: 2000) + max_description_length: Maximum description length (default: 500) + timeout: Optional timeout in seconds for LLM call (default: None, no timeout) + + Returns: + OntologyExtractionResponse containing validated ontology classes + + Raises: + ValueError: If scenario is empty or invalid + asyncio.TimeoutError: If extraction times out + + Examples: + >>> extractor = OntologyExtractor(llm_client) + >>> response = await extractor.extract_ontology_classes( + ... scenario="A hospital manages patient records...", + ... domain="Healthcare", + ... max_classes=10, + ... timeout=30.0 + ... ) + >>> len(response.classes) + 7 + """ + # Start timing + start_time = time.time() + + # Validate input + if not scenario or not scenario.strip(): + logger.error("Scenario description is empty") + raise ValueError("Scenario description cannot be empty") + + scenario = scenario.strip() + + logger.info( + f"Starting ontology extraction - scenario_length={len(scenario)}, " + f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, " + f"timeout={timeout}" + ) + + try: + # Step 1: Call LLM for extraction with timeout + logger.info("Step 1: Calling LLM for ontology extraction") + llm_start_time = time.time() + + if timeout is not None: + # Wrap LLM call with timeout + try: + response = await asyncio.wait_for( + self._call_llm_for_extraction( + scenario=scenario, + domain=domain, + max_classes=max_classes, + llm_temperature=llm_temperature, + llm_max_tokens=llm_max_tokens, + ), + timeout=timeout + ) + except asyncio.TimeoutError: + llm_duration = time.time() - llm_start_time + logger.error( + f"LLM extraction timed out after {timeout} seconds " + f"(actual duration: {llm_duration:.2f}s)" + ) + # Return empty response on timeout + return OntologyExtractionResponse( + classes=[], + domain=domain or "Unknown", + ) + else: + # No timeout specified, call directly + response = await self._call_llm_for_extraction( + scenario=scenario, + domain=domain, + max_classes=max_classes, + llm_temperature=llm_temperature, + llm_max_tokens=llm_max_tokens, + ) + + llm_duration = time.time() - llm_start_time + logger.info( + f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s" + ) + + # Step 2: First-layer validation (string validation and cleaning) + logger.info("Step 2: Performing first-layer validation (string validation)") + validation_start_time = time.time() + + response = self._validate_and_clean( + response=response, + max_description_length=max_description_length, + ) + + validation_duration = time.time() - validation_start_time + logger.info( + f"After first-layer validation: {len(response.classes)} classes remain " + f"(validation took {validation_duration:.2f}s)" + ) + + # Check if we have enough classes after first-layer validation + if len(response.classes) < min_classes: + logger.warning( + f"Only {len(response.classes)} classes remain after validation, " + f"which is below minimum of {min_classes}" + ) + + # Step 3: Second-layer validation (OWL semantic validation) + if enable_owl_validation and response.classes: + logger.info("Step 3: Performing second-layer validation (OWL validation)") + owl_start_time = time.time() + + is_valid, errors, world = self.owl_validator.validate_ontology_classes( + classes=response.classes, + ) + + owl_duration = time.time() - owl_start_time + + if not is_valid: + logger.warning( + f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}" + ) + + # Filter invalid classes based on errors + response = self._filter_invalid_classes( + response=response, + errors=errors, + ) + + logger.info( + f"After second-layer validation: {len(response.classes)} classes remain" + ) + else: + logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s") + else: + if not enable_owl_validation: + logger.info("Step 3: OWL validation disabled, skipping") + else: + logger.info("Step 3: No classes to validate, skipping OWL validation") + + # Calculate total duration + total_duration = time.time() - start_time + + # Log extraction statistics + logger.info( + f"Ontology extraction completed - " + f"final_class_count={len(response.classes)}, " + f"domain={response.domain}, " + f"total_duration={total_duration:.2f}s, " + f"llm_duration={llm_duration:.2f}s" + ) + + return response + + except asyncio.TimeoutError: + # Re-raise timeout errors + total_duration = time.time() - start_time + logger.error( + f"Ontology extraction timed out after {timeout} seconds " + f"(total duration: {total_duration:.2f}s)", + exc_info=True + ) + raise + except Exception as e: + total_duration = time.time() - start_time + logger.error( + f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}", + exc_info=True + ) + # Return empty response on failure + return OntologyExtractionResponse( + classes=[], + domain=domain or "Unknown", + ) + + async def _call_llm_for_extraction( + self, + scenario: str, + domain: Optional[str], + max_classes: int, + llm_temperature: float, + llm_max_tokens: int, + ) -> OntologyExtractionResponse: + """Call LLM to extract ontology classes from scenario. + + This method renders the extraction prompt using the Jinja2 template + and calls the LLM with structured output to get ontology classes. + + Args: + scenario: Scenario description text + domain: Optional domain hint + max_classes: Maximum number of classes to extract + llm_temperature: LLM temperature parameter + llm_max_tokens: LLM max tokens parameter + + Returns: + OntologyExtractionResponse from LLM + + Raises: + Exception: If LLM call fails + """ + try: + # Render prompt using template + prompt_content = await render_ontology_extraction_prompt( + scenario=scenario, + domain=domain, + max_classes=max_classes, + json_schema=OntologyExtractionResponse.model_json_schema(), + ) + + logger.debug(f"Rendered prompt length: {len(prompt_content)}") + + # Create messages for LLM + messages = [ + { + "role": "system", + "content": ( + "You are an expert ontology engineer specializing in knowledge " + "representation and OWL standards. Extract ontology classes from " + "scenario descriptions following the provided instructions. " + "Return valid JSON conforming to the schema." + ), + }, + { + "role": "user", + "content": prompt_content, + }, + ] + + # Call LLM with structured output + logger.debug( + f"Calling LLM with temperature={llm_temperature}, " + f"max_tokens={llm_max_tokens}" + ) + + response = await self.llm_client.response_structured( + messages=messages, + response_model=OntologyExtractionResponse, + ) + + logger.info( + f"LLM extraction successful - extracted {len(response.classes)} classes" + ) + + return response + + except Exception as e: + logger.error( + f"LLM extraction failed: {str(e)}", + exc_info=True + ) + raise + + def _validate_and_clean( + self, + response: OntologyExtractionResponse, + max_description_length: int, + ) -> OntologyExtractionResponse: + """Perform first-layer validation: string validation and cleaning. + + This method validates and cleans the extracted ontology classes: + 1. Validate class names (PascalCase, no reserved words) + 2. Sanitize invalid class names + 3. Truncate long descriptions + 4. Remove duplicate classes + + Args: + response: OntologyExtractionResponse from LLM + max_description_length: Maximum description length + + Returns: + Cleaned OntologyExtractionResponse + """ + if not response.classes: + logger.debug("No classes to validate") + return response + + logger.debug(f"Validating {len(response.classes)} classes") + + validated_classes = [] + + for ontology_class in response.classes: + # Validate class name + is_valid, error_msg = self.validator.validate_class_name( + ontology_class.name + ) + + if not is_valid: + logger.warning( + f"Invalid class name '{ontology_class.name}': {error_msg}" + ) + + # Attempt to sanitize + sanitized_name = self.validator.sanitize_class_name( + ontology_class.name + ) + + logger.info( + f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'" + ) + + # Update class name + ontology_class.name = sanitized_name + + # Re-validate sanitized name + is_valid, error_msg = self.validator.validate_class_name( + sanitized_name + ) + + if not is_valid: + logger.error( + f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. " + "Skipping this class." + ) + continue + + # Truncate description if too long + if ontology_class.description: + original_length = len(ontology_class.description) + ontology_class.description = self.validator.truncate_description( + ontology_class.description, + max_length=max_description_length, + ) + + if len(ontology_class.description) < original_length: + logger.debug( + f"Truncated description for '{ontology_class.name}': " + f"{original_length} -> {len(ontology_class.description)} chars" + ) + + validated_classes.append(ontology_class) + + # Remove duplicates (case-insensitive) + original_count = len(validated_classes) + validated_classes = self.validator.remove_duplicates(validated_classes) + + if len(validated_classes) < original_count: + logger.info( + f"Removed {original_count - len(validated_classes)} duplicate classes" + ) + + # Return cleaned response + return OntologyExtractionResponse( + classes=validated_classes, + domain=response.domain, + ) + + def _filter_invalid_classes( + self, + response: OntologyExtractionResponse, + errors: List[str], + ) -> OntologyExtractionResponse: + """Filter invalid classes based on OWL validation errors. + + This method analyzes OWL validation errors and removes classes + that caused validation failures (e.g., circular inheritance, + inconsistencies). + + Args: + response: OntologyExtractionResponse to filter + errors: List of error messages from OWL validation + + Returns: + Filtered OntologyExtractionResponse + """ + if not errors: + return response + + logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors") + + # Extract class names mentioned in errors + invalid_class_names = set() + + for error in errors: + # Look for class names in error messages + for ontology_class in response.classes: + if ontology_class.name in error: + invalid_class_names.add(ontology_class.name) + logger.debug( + f"Class '{ontology_class.name}' marked as invalid due to error: {error}" + ) + + # Filter out invalid classes + if invalid_class_names: + original_count = len(response.classes) + + filtered_classes = [ + c for c in response.classes + if c.name not in invalid_class_names + ] + + logger.info( + f"Filtered out {original_count - len(filtered_classes)} invalid classes: " + f"{invalid_class_names}" + ) + + return OntologyExtractionResponse( + classes=filtered_classes, + domain=response.domain, + ) + + return response diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index bfc0bc88..8c3e31b4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -25,6 +25,15 @@ class TripletExtractor: """ self.llm_client = llm_client + def _get_language(self) -> str: + """Get the configured language for entity descriptions + + Returns: + Language code ("zh" or "en") + """ + from app.core.config import settings + return settings.DEFAULT_LANGUAGE + async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse: """Process a single statement and return extracted triplets and entities""" # Render the prompt using helper function @@ -40,7 +49,8 @@ class TripletExtractor: statement=statement.statement, chunk_content=chunk_content, json_schema=TripletExtractionResponse.model_json_schema(), - predicate_instructions=PREDICATE_DEFINITIONS + predicate_instructions=PREDICATE_DEFINITIONS, + language=self._get_language() ) # Create messages for LLM diff --git a/api/app/core/memory/utils/alias_utils.py b/api/app/core/memory/utils/alias_utils.py index df75752a..ff139128 100644 --- a/api/app/core/memory/utils/alias_utils.py +++ b/api/app/core/memory/utils/alias_utils.py @@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li key=lambda eid: ( _strength_rank(eid), len(getattr(entity_by_id.get(eid), 'description', '') or ''), - len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '') + 0 # 临时占位 ), reverse=True ) diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 50593e49..a4d2af95 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -177,7 +177,7 @@ def render_entity_dedup_prompt( # Args: # entity_a: Dict of entity A attributes -async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str: +async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str: """ Renders the triplet extraction prompt using the extract_triplet.jinja2 template. @@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j chunk_content: The content of the chunk to process json_schema: JSON schema for the expected output format predicate_instructions: Optional predicate instructions + language: The language to use for entity descriptions ("zh" for Chinese, "en" for English) Returns: Rendered prompt content as string @@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j statement=statement, chunk_content=chunk_content, json_schema=json_schema, - predicate_instructions=predicate_instructions + predicate_instructions=predicate_instructions, + language=language ) # 记录渲染结果到提示日志(与示例日志结构一致) log_prompt_rendering('triplet extraction', rendered_prompt) @@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j 'statement': 'str', 'chunk_content': 'str', 'json_schema': 'TripletExtractionResponse.schema', - 'predicate_instructions': 'PREDICATE_DEFINITIONS' + 'predicate_instructions': 'PREDICATE_DEFINITIONS', + 'language': language }) return rendered_prompt @@ -213,6 +216,7 @@ async def render_memory_summary_prompt( chunk_texts: str, json_schema: dict, max_words: int = 200, + language: str = "zh", ) -> str: """ Renders the memory summary prompt using the memory_summary.jinja2 template. @@ -221,6 +225,7 @@ async def render_memory_summary_prompt( chunk_texts: Concatenated text of conversation chunks json_schema: JSON schema for the expected output format max_words: Maximum words for the summary + language: The language to use for summary generation ("zh" for Chinese, "en" for English) Returns: Rendered prompt content as string. @@ -230,12 +235,14 @@ async def render_memory_summary_prompt( chunk_texts=chunk_texts, json_schema=json_schema, max_words=max_words, + language=language, ) log_prompt_rendering('memory summary', rendered_prompt) log_template_rendering('memory_summary.jinja2', { 'chunk_texts_len': len(chunk_texts or ""), 'max_words': max_words, - 'json_schema': 'MemorySummaryResponse.schema' + 'json_schema': 'MemorySummaryResponse.schema', + 'language': language }) return rendered_prompt @@ -388,24 +395,65 @@ async def render_memory_insight_prompt( return rendered_prompt -async def render_episodic_title_and_type_prompt(content: str) -> str: +async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str: """ Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template. Args: content: The content of the episodic memory summary to analyze + language: The language to use for title generation ("zh" for Chinese, "en" for English) Returns: Rendered prompt content as string """ template = prompt_env.get_template("episodic_type_classification.jinja2") - rendered_prompt = template.render(content=content) + rendered_prompt = template.render(content=content, language=language) # 记录渲染结果到提示日志 log_prompt_rendering('episodic title and type classification', rendered_prompt) # 可选:记录模板渲染信息 log_template_rendering('episodic_type_classification.jinja2', { - 'content_len': len(content) if content else 0 + 'content_len': len(content) if content else 0, + 'language': language + }) + + return rendered_prompt + + +async def render_ontology_extraction_prompt( + scenario: str, + domain: str | None = None, + max_classes: int = 15, + json_schema: dict | None = None +) -> str: + """ + Renders the ontology extraction prompt using the extract_ontology.jinja2 template. + + Args: + scenario: The scenario description text to extract ontology classes from + domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education") + max_classes: Maximum number of classes to extract (default: 15) + json_schema: JSON schema for the expected output format + + Returns: + Rendered prompt content as string + """ + template = prompt_env.get_template("extract_ontology.jinja2") + rendered_prompt = template.render( + scenario=scenario, + domain=domain, + max_classes=max_classes, + json_schema=json_schema + ) + + # 记录渲染结果到提示日志 + log_prompt_rendering('ontology extraction', rendered_prompt) + # 可选:记录模板渲染信息 + log_template_rendering('extract_ontology.jinja2', { + 'scenario_len': len(scenario) if scenario else 0, + 'domain': domain, + 'max_classes': max_classes, + 'json_schema': 'OntologyExtractionResponse.schema' }) return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 index be53c9d4..7fb465a2 100644 --- a/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 @@ -9,7 +9,8 @@ - 类型: "{{ entity_a.entity_type | default('') }}" - 描述: "{{ entity_a.description | default('') }}" - 别名: {{ entity_a.aliases | default([]) }} -- 摘要: "{{ entity_a.fact_summary | default('') }}" +{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #} +{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #} - 连接强弱: "{{ entity_a.connect_strength | default('') }}" 实体B: @@ -17,7 +18,8 @@ - 类型: "{{ entity_b.entity_type | default('') }}" - 描述: "{{ entity_b.description | default('') }}" - 别名: {{ entity_b.aliases | default([]) }} -- 摘要: "{{ entity_b.fact_summary | default('') }}" +{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #} +{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #} - 连接强弱: "{{ entity_b.connect_strength | default('') }}" 上下文: diff --git a/api/app/core/memory/utils/prompt/prompts/episodic_type_classification.jinja2 b/api/app/core/memory/utils/prompt/prompts/episodic_type_classification.jinja2 index fa382ec7..d778890b 100644 --- a/api/app/core/memory/utils/prompt/prompts/episodic_type_classification.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/episodic_type_classification.jinja2 @@ -1,8 +1,19 @@ === Task === Generate a concise title and classify the episodic memory into the most appropriate category. +{% if language == "zh" %} +**重要:请使用中文生成标题和分类。** +{% else %} +**Important: Please generate the title and classification in English.** +{% endif %} + === Requirements === - Extract a clear, concise title (10-20 characters) that captures the core content +{% if language == "zh" %} +- 标题必须使用中文 +{% else %} +- Title must be in English +{% endif %} - Classify into exactly one category based on the primary theme - Be specific and avoid ambiguity - Output must be valid JSON conforming to the schema below diff --git a/api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2 new file mode 100644 index 00000000..80594ad9 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2 @@ -0,0 +1,210 @@ +===Task=== +Extract ontology classes from the given scenario description following ontology engineering standards. + +===Role=== +You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances. + +===Scenario Description=== +{{ scenario }} + +{% if domain -%} +===Domain Hint=== +This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes. +{%- endif %} + +===Extraction Rules=== + +**1. Abstract Classes, Not Instances:** +- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis") +- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15") +- Think in terms of "types of things" rather than "specific things" + +**2. Naming Convention (PascalCase):** +- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces +- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest" +- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test" +- Use clear, descriptive names in English +- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA") +- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试") + +**3. Domain Relevance:** +- Focus on classes that are central to the scenario's domain +- Prioritize classes that represent key concepts, entities, or relationships +- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning + +**4. Class Quantity:** +- Extract between 5 and {{ max_classes }} classes +- Aim for a balanced set covering the main concepts in the scenario +- Quality over quantity: prefer well-defined classes over exhaustive lists + +**5. Clear Descriptions:** +- Provide concise, informative descriptions in Chinese (max 500 characters) +- Describe what the class represents, not specific instances +- Use clear, natural Chinese language that explains the class's role in the domain + +**6. Concrete Examples:** +- Provide 2-5 concrete instance examples in Chinese for each class +- Examples should be specific, realistic instances of the class +- Examples help clarify the class's scope and meaning +- Use natural Chinese language for examples +- Example format: ["示例1", "示例2", "示例3"] + +**7. Class Hierarchy:** +- Identify parent-child relationships where applicable +- Use the parent_class field to specify inheritance +- Parent class must be one of the extracted classes or a standard OWL class +- Leave parent_class as null for top-level classes + +**8. Entity Types:** +- Classify each class with an appropriate entity_type +- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role" +- Choose the most specific type that applies + +**9. OWL Reserved Words:** +- Do NOT use OWL reserved words as class names +- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal" +- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class") + +**10. Language Consistency:** +- Extract all class names in English (PascalCase format) for the "name" field +- Provide Chinese translation for class names in the "name_chinese" field +- Descriptions MUST be in Chinese (中文) +- Examples MUST be in Chinese (中文) +- Use clear, natural Chinese language for descriptions and examples + +===Examples=== + +**Example 1 (Healthcare Domain):** +Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments." + +Output: +{ + "classes": [ + { + "name": "Patient", + "name_chinese": "患者", + "description": "在医疗机构接受医疗护理或治疗的人", + "examples": ["张三", "李四", "患有糖尿病的老年患者"], + "parent_class": null, + "entity_type": "Person", + "domain": "Healthcare" + }, + { + "name": "MedicalProcedure", + "name_chinese": "医疗程序", + "description": "为医疗诊断或治疗而执行的系统性操作流程", + "examples": ["手术", "血液检查", "X光检查", "疫苗接种"], + "parent_class": null, + "entity_type": "Process", + "domain": "Healthcare" + }, + { + "name": "Diagnosis", + "name_chinese": "诊断", + "description": "基于症状和检查结果对疾病或状况的识别", + "examples": ["糖尿病诊断", "癌症诊断", "流感诊断"], + "parent_class": null, + "entity_type": "Concept", + "domain": "Healthcare" + }, + { + "name": "Doctor", + "name_chinese": "医生", + "description": "诊断和治疗患者的持证医疗专业人员", + "examples": ["全科医生", "外科医生", "心脏病专家"], + "parent_class": null, + "entity_type": "Role", + "domain": "Healthcare" + }, + { + "name": "Treatment", + "name_chinese": "治疗", + "description": "为治愈或管理疾病状况而提供的医疗护理或疗法", + "examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"], + "parent_class": null, + "entity_type": "Process", + "domain": "Healthcare" + } + ], + "domain": "Healthcare", + "namespace": "http://example.org/healthcare#" +} + +**Example 2 (Education Domain):** +Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees." + +Output: +{ + "classes": [ + { + "name": "Student", + "name_chinese": "学生", + "description": "在教育机构注册学习的人", + "examples": ["本科生", "研究生", "在职学生"], + "parent_class": null, + "entity_type": "Role", + "domain": "Education" + }, + { + "name": "Course", + "name_chinese": "课程", + "description": "涵盖特定学科或主题的结构化教育课程", + "examples": ["计算机科学导论", "微积分I", "世界历史"], + "parent_class": null, + "entity_type": "Concept", + "domain": "Education" + }, + { + "name": "Professor", + "name_chinese": "教授", + "description": "教授课程并进行研究的学术教师", + "examples": ["助理教授", "副教授", "正教授"], + "parent_class": null, + "entity_type": "Role", + "domain": "Education" + }, + { + "name": "AcademicProgram", + "name_chinese": "学术项目", + "description": "通向学位或证书的结构化课程体系", + "examples": ["理学学士", "文学硕士", "博士项目"], + "parent_class": null, + "entity_type": "Concept", + "domain": "Education" + }, + { + "name": "Assignment", + "name_chinese": "作业", + "description": "分配给学生以评估学习成果的任务或项目", + "examples": ["论文", "习题集", "研究报告", "实验报告"], + "parent_class": null, + "entity_type": "Object", + "domain": "Education" + }, + { + "name": "Lecture", + "name_chinese": "讲座", + "description": "由教师进行的教育性演讲或讲座", + "examples": ["入门讲座", "客座讲座", "在线讲座"], + "parent_class": null, + "entity_type": "Event", + "domain": "Education" + } + ], + "domain": "Education", + "namespace": "http://example.org/education#" +} + +===Output Format=== + +**JSON Requirements:** +- Use only ASCII double quotes (") for JSON structure +- Never use Chinese quotation marks ("") or Unicode quotes +- Escape quotation marks in text with backslashes (\") +- Ensure proper string closure and comma separation +- No line breaks within JSON string values +- All class names must be in PascalCase format +- All class names must be unique (case-insensitive) +- Extract between 5 and {{ max_classes }} classes + +{{ json_schema }} diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index 03691a04..67df162a 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -5,6 +5,12 @@ ===Task=== Extract entities and knowledge triplets from the given statement. +{% if language == "zh" %} +**重要:请使用中文生成实体描述(description)和示例(example)。** +{% else %} +**Important: Please generate entity descriptions and examples in English.** +{% endif %} + ===Inputs=== **Chunk Content:** "{{ chunk_content }}" **Statement:** "{{ statement }}" @@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement. **Entity Extraction:** - Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification +{% if language == "zh" %} +- **实体描述(description)必须使用中文** +- **示例(example)必须使用中文** +{% else %} +- **Entity descriptions must be in English** +- **Examples must be in English** +{% endif %} - **Semantic Memory Classification (is_explicit_memory):** * Set to `true` if the entity represents **explicit/semantic memory**: - **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主" @@ -334,9 +347,11 @@ Output: - Escape quotation marks in text with backslashes (\") - Ensure proper string closure and comma separation - No line breaks within JSON string values -- The output language should ALWAYS match the input language -- If input is in English, extract statements in English -- If input is in Chinese, extract statements in Chinese +{% if language == "zh" %} +- **语言要求:实体描述(description)和示例(example)必须使用中文** +{% else %} +- **Language Requirement: Entity descriptions and examples must be in English** +{% endif %} - Preserve the original language and do not translate {{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 b/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 index 1dd86ca3..82f91cc4 100644 --- a/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 @@ -5,10 +5,21 @@ === Task === Summarize the provided conversation chunks into a concise Memory summary. +{% if language == "zh" %} +**重要:请使用中文生成摘要内容。** +{% else %} +**Important: Please generate the summary content in English.** +{% endif %} + === Requirements === - Focus on factual statements, user preferences, relationships, and salient temporal context. - Avoid repetition and filler; be specific. - Keep it under {{ max_words or 200 }} words. +{% if language == "zh" %} +- 摘要内容必须使用中文 +{% else %} +- Summary content must be in English +{% endif %} - Output must be valid JSON conforming to the schema below. === Input === @@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary. 4. Do not include line breaks within JSON string values 5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\"" -The output language should always be the same as the input language. +{% if language == "zh" %} +**语言要求:输出内容必须使用中文。** +{% else %} +**Language Requirement: The output content must be in English.** +{% endif %} + Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below: {{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/validation/__init__.py b/api/app/core/memory/utils/validation/__init__.py new file mode 100644 index 00000000..d5dd41e7 --- /dev/null +++ b/api/app/core/memory/utils/validation/__init__.py @@ -0,0 +1,10 @@ +"""Validation utilities for ontology extraction. + +This module provides validation classes for ontology class names, +descriptions, and OWL compliance checking. +""" + +from .ontology_validator import OntologyValidator +from .owl_validator import OWLValidator + +__all__ = ['OntologyValidator', 'OWLValidator'] diff --git a/api/app/core/memory/utils/validation/ontology_validator.py b/api/app/core/memory/utils/validation/ontology_validator.py new file mode 100644 index 00000000..eb7492ad --- /dev/null +++ b/api/app/core/memory/utils/validation/ontology_validator.py @@ -0,0 +1,268 @@ +"""String validation for ontology class names and descriptions. + +This module provides the OntologyValidator class for validating and sanitizing +ontology class names according to OWL standards and naming conventions. + +Classes: + OntologyValidator: Validates class names, removes duplicates, and truncates descriptions +""" + +import logging +import re +from typing import List, Tuple + +from app.core.memory.models.ontology_models import OntologyClass + + +logger = logging.getLogger(__name__) + + +class OntologyValidator: + """Validator for ontology class names and descriptions. + + This validator performs string-level validation including: + - PascalCase naming convention validation + - OWL reserved word checking + - Duplicate class name removal + - Description length truncation + + Attributes: + OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names + """ + + # OWL reserved words that cannot be used as class names + OWL_RESERVED_WORDS = { + 'Thing', 'Nothing', 'Class', 'Property', + 'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty', + 'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty', + 'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty', + 'Restriction', 'Ontology', 'Individual', 'NamedIndividual', + 'Annotation', 'AnnotationProperty', 'Axiom', + 'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties', + 'Datatype', 'DataRange', 'Literal', + 'DeprecatedClass', 'DeprecatedProperty', + 'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo', + 'BackwardCompatibleWith', 'OntologyProperty', + } + + def validate_class_name(self, name: str) -> Tuple[bool, str]: + """Validate that a class name follows OWL naming conventions. + + Validation rules: + 1. Must not be empty + 2. Must start with an uppercase letter (PascalCase) + 3. Cannot contain spaces + 4. Can only contain alphanumeric characters and underscores + 5. Cannot be an OWL reserved word + + Args: + name: The class name to validate + + Returns: + Tuple of (is_valid, error_message) + - is_valid: True if the name is valid, False otherwise + - error_message: Empty string if valid, error description if invalid + + Examples: + >>> validator = OntologyValidator() + >>> validator.validate_class_name("MedicalProcedure") + (True, "") + >>> validator.validate_class_name("medical procedure") + (False, "Class name 'medical procedure' cannot contain spaces") + >>> validator.validate_class_name("Thing") + (False, "Class name 'Thing' is an OWL reserved word") + """ + logger.debug(f"Validating class name: '{name}'") + + # Check if empty + if not name or not name.strip(): + error_msg = "Class name cannot be empty" + logger.warning(f"Validation failed: {error_msg}") + return False, error_msg + + name = name.strip() + + # Check if it's an OWL reserved word + if name in self.OWL_RESERVED_WORDS: + error_msg = f"Class name '{name}' is an OWL reserved word" + logger.warning(f"Validation failed: {error_msg}") + return False, error_msg + + # Check if starts with uppercase letter + if not name[0].isupper(): + error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)" + logger.warning(f"Validation failed: {error_msg}") + return False, error_msg + + # Check for spaces + if ' ' in name: + error_msg = f"Class name '{name}' cannot contain spaces" + logger.warning(f"Validation failed: {error_msg}") + return False, error_msg + + # Check for invalid characters (only alphanumeric and underscore allowed) + if not re.match(r'^[A-Za-z0-9_]+$', name): + error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed" + logger.warning(f"Validation failed: {error_msg}") + return False, error_msg + + logger.debug(f"Class name '{name}' is valid") + return True, "" + + def sanitize_class_name(self, name: str) -> str: + """Attempt to sanitize an invalid class name into a valid format. + + Sanitization steps: + 1. Strip whitespace + 2. Remove invalid characters + 3. Replace spaces with empty string (PascalCase) + 4. Capitalize first letter of each word + 5. If result is empty or starts with number, prefix with 'Class' + + Args: + name: The class name to sanitize + + Returns: + Sanitized class name that should pass validation + + Examples: + >>> validator = OntologyValidator() + >>> validator.sanitize_class_name("medical procedure") + 'MedicalProcedure' + >>> validator.sanitize_class_name("patient-record") + 'PatientRecord' + >>> validator.sanitize_class_name("123invalid") + 'Class123Invalid' + """ + logger.debug(f"Sanitizing class name: '{name}'") + + if not name or not name.strip(): + logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'") + return "UnnamedClass" + + # Strip whitespace + name = name.strip() + original_name = name + + # Split on spaces, hyphens, and underscores, then capitalize each word + words = re.split(r'[\s\-_]+', name) + + # Capitalize first letter of each word and keep rest as is + sanitized_words = [] + for word in words: + if word: + # Remove non-alphanumeric characters except underscore + clean_word = re.sub(r'[^A-Za-z0-9_]', '', word) + if clean_word: + # Capitalize first letter + sanitized_words.append(clean_word[0].upper() + clean_word[1:]) + + # Join words + sanitized = ''.join(sanitized_words) + + # If empty or starts with number, prefix with 'Class' + if not sanitized or sanitized[0].isdigit(): + sanitized = 'Class' + sanitized + logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'") + + # If it's a reserved word, append 'Class' suffix + if sanitized in self.OWL_RESERVED_WORDS: + sanitized = sanitized + 'Class' + logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'") + + logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'") + return sanitized + + def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]: + """Remove duplicate ontology classes based on case-insensitive name comparison. + + When duplicates are found, keeps the first occurrence and discards subsequent ones. + Comparison is case-insensitive to catch variations like 'Patient' and 'patient'. + + Args: + classes: List of OntologyClass objects + + Returns: + List of OntologyClass objects with duplicates removed + + Examples: + >>> validator = OntologyValidator() + >>> classes = [ + ... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"), + ... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"), + ... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"), + ... ] + >>> unique = validator.remove_duplicates(classes) + >>> len(unique) + 2 + >>> [c.name for c in unique] + ['Patient', 'Doctor'] + """ + if not classes: + logger.debug("No classes to check for duplicates") + return classes + + logger.debug(f"Checking {len(classes)} classes for duplicates") + + seen_names = set() + unique_classes = [] + duplicates_found = [] + + for ontology_class in classes: + # Use lowercase for comparison + name_lower = ontology_class.name.lower() + + if name_lower not in seen_names: + seen_names.add(name_lower) + unique_classes.append(ontology_class) + else: + duplicates_found.append(ontology_class.name) + logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'") + + if duplicates_found: + logger.info( + f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}" + ) + else: + logger.debug("No duplicate classes found") + + return unique_classes + + def truncate_description(self, description: str, max_length: int = 500) -> str: + """Truncate a description to a maximum length. + + If the description exceeds max_length, it will be truncated and + an ellipsis (...) will be appended to indicate truncation. + + Args: + description: The description text to truncate + max_length: Maximum allowed length (default: 500) + + Returns: + Truncated description string + + Examples: + >>> validator = OntologyValidator() + >>> long_desc = "A" * 600 + >>> truncated = validator.truncate_description(long_desc, max_length=500) + >>> len(truncated) + 500 + >>> truncated.endswith("...") + True + """ + if not description: + return "" + + if len(description) <= max_length: + return description + + # Truncate and add ellipsis + # Reserve 3 characters for "..." + truncate_at = max_length - 3 + truncated = description[:truncate_at] + "..." + + logger.debug( + f"Truncated description from {len(description)} to {len(truncated)} characters" + ) + + return truncated diff --git a/api/app/core/memory/utils/validation/owl_validator.py b/api/app/core/memory/utils/validation/owl_validator.py new file mode 100644 index 00000000..2398d528 --- /dev/null +++ b/api/app/core/memory/utils/validation/owl_validator.py @@ -0,0 +1,585 @@ +"""OWL semantic validation for ontology classes using Owlready2. + +This module provides the OWLValidator class for validating ontology classes +against OWL standards using the Owlready2 library. It performs semantic +validation including consistency checking, circular inheritance detection, +and OWL file export. + +Classes: + OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats +""" + +import logging +from typing import List, Optional, Tuple + +from owlready2 import ( + World, + Thing, + get_ontology, + sync_reasoner_pellet, + OwlReadyInconsistentOntologyError, +) + +from app.core.memory.models.ontology_models import OntologyClass +logger = logging.getLogger(__name__) + + +class OWLValidator: + """Validator for OWL semantic validation of ontology classes. + + This validator performs semantic-level validation using Owlready2 including: + - Creating OWL classes from ontology class definitions + - Running consistency checking with Pellet reasoner + - Detecting circular inheritance + - Validating Protégé compatibility + - Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples) + + Attributes: + base_namespace: Base URI for the ontology namespace + """ + + def __init__(self, base_namespace: str = "http://example.org/ontology#"): + """Initialize the OWL validator. + + Args: + base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#) + """ + self.base_namespace = base_namespace + + def validate_ontology_classes( + self, + classes: List[OntologyClass], + ) -> Tuple[bool, List[str], Optional[World]]: + """Validate extracted ontology classes against OWL standards. + + This method creates an OWL ontology from the provided classes using Owlready2, + runs consistency checking with the Pellet reasoner, and detects common issues + like circular inheritance. + + Args: + classes: List of OntologyClass objects to validate + + Returns: + Tuple of (is_valid, error_messages, world): + - is_valid: True if ontology is valid and consistent, False otherwise + - error_messages: List of error/warning messages + - world: Owlready2 World object containing the ontology (None if validation failed) + + Examples: + >>> validator = OWLValidator() + >>> classes = [ + ... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"), + ... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"), + ... ] + >>> is_valid, errors, world = validator.validate_ontology_classes(classes) + >>> is_valid + True + >>> len(errors) + 0 + """ + if not classes: + return False, ["No classes provided for validation"], None + + errors = [] + + try: + # Create a new world (isolated ontology environment) + world = World() + + # Use a proper ontology IRI + # Owlready2 expects the IRI to end with .owl or similar + onto_iri = self.base_namespace.rstrip('#/') + if not onto_iri.endswith('.owl'): + onto_iri = onto_iri + '.owl' + + # Create ontology + onto = world.get_ontology(onto_iri) + + with onto: + # Dictionary to store created OWL classes for parent reference + owl_classes = {} + + # First pass: Create all classes without parent relationships + for ontology_class in classes: + try: + # Create OWL class dynamically using type() with Thing as base + # The key is to NOT set namespace in the dict, let Owlready2 handle it + owl_class = type( + ontology_class.name, # Class name + (Thing,), # Base classes + {} # Class dict (empty, let Owlready2 manage) + ) + + # Add label (rdfs:label) - include both English and Chinese names + labels = [ontology_class.name] + if ontology_class.name_chinese: + labels.append(ontology_class.name_chinese) + owl_class.label = labels + + # Add comment (rdfs:comment) with description + if ontology_class.description: + owl_class.comment = [ontology_class.description] + + # Store for parent relationship setup + owl_classes[ontology_class.name] = owl_class + + logger.debug( + f"Created OWL class: {ontology_class.name} " + f"(Chinese: {ontology_class.name_chinese}) " + f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}" + ) + + except Exception as e: + error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}" + errors.append(error_msg) + logger.error(error_msg, exc_info=True) + + # Second pass: Set up parent relationships + for ontology_class in classes: + if ontology_class.parent_class and ontology_class.name in owl_classes: + parent_name = ontology_class.parent_class + + # Check if parent exists + if parent_name in owl_classes: + try: + child_class = owl_classes[ontology_class.name] + parent_class = owl_classes[parent_name] + + # Set parent by modifying is_a + child_class.is_a = [parent_class] + + logger.debug( + f"Set parent relationship: {ontology_class.name} -> {parent_name}" + ) + + except Exception as e: + error_msg = ( + f"Failed to set parent relationship " + f"'{ontology_class.name}' -> '{parent_name}': {str(e)}" + ) + errors.append(error_msg) + logger.warning(error_msg) + else: + warning_msg = ( + f"Parent class '{parent_name}' not found for '{ontology_class.name}'" + ) + errors.append(warning_msg) + logger.warning(warning_msg) + + # Check for circular inheritance + for class_name, owl_class in owl_classes.items(): + if self._has_circular_inheritance(owl_class): + error_msg = f"Circular inheritance detected for class '{class_name}'" + errors.append(error_msg) + logger.error(error_msg) + + # Run consistency checking with Pellet reasoner + try: + logger.info("Running Pellet reasoner for consistency checking...") + sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True) + logger.info("Consistency check passed") + + except OwlReadyInconsistentOntologyError as e: + error_msg = f"Ontology is inconsistent: {str(e)}" + errors.append(error_msg) + logger.error(error_msg) + return False, errors, world + + except Exception as e: + # Reasoner errors are often due to Java not being installed or configured + # Log as warning but don't fail validation - ontology structure is still valid + warning_msg = f"Reasoner check skipped: {str(e)}" + if str(e).strip(): # Only log if there's an actual error message + logger.warning(warning_msg) + else: + logger.warning("Reasoner check skipped: Java may not be installed or configured") + # Continue - ontology structure is valid even without reasoner check + + # If we have errors (excluding warnings), validation failed + is_valid = len(errors) == 0 + + return is_valid, errors, world + + except Exception as e: + error_msg = f"OWL validation failed: {str(e)}" + errors.append(error_msg) + logger.error(error_msg, exc_info=True) + return False, errors, None + + def _has_circular_inheritance(self, owl_class) -> bool: + """Check if an OWL class has circular inheritance. + + Circular inheritance occurs when a class inherits from itself through + a chain of parent relationships (e.g., A -> B -> C -> A). + + Args: + owl_class: Owlready2 class object to check + + Returns: + True if circular inheritance is detected, False otherwise + """ + visited = set() + current = owl_class + + while current: + # Get class IRI or name as identifier + class_id = str(current.iri) if hasattr(current, 'iri') else str(current) + + if class_id in visited: + # Found a cycle + return True + + visited.add(class_id) + + # Get parent classes (is_a relationship) + parents = getattr(current, 'is_a', []) + + # Filter out Thing and other base classes + parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')] + + if not parent_classes: + # No more parents, no cycle + break + + # Check first parent (in single inheritance) + current = parent_classes[0] if parent_classes else None + + return False + + def export_to_owl( + self, + world: World, + output_path: Optional[str] = None, + format: str = "rdfxml", + classes: Optional[List] = None + ) -> str: + """Export ontology to OWL file in specified format. + + Supported formats: + - rdfxml: RDF/XML format (default, most compatible) + - turtle: Turtle format (more readable) + - ntriples: N-Triples format (simplest) + - json: JSON format (simplified, human-readable) + + Args: + world: Owlready2 World object containing the ontology + output_path: Optional file path to save the ontology (if None, returns string) + format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml") + classes: Optional list of OntologyClass objects (required for json format) + + Returns: + String representation of the exported ontology + + Raises: + ValueError: If format is not supported + RuntimeError: If export fails + + Examples: + >>> validator = OWLValidator() + >>> is_valid, errors, world = validator.validate_ontology_classes(classes) + >>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml") + """ + # Validate format + valid_formats = ["rdfxml", "turtle", "ntriples", "json"] + if format not in valid_formats: + raise ValueError( + f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}" + ) + + # JSON format doesn't need OWL processing + if format == "json": + if not classes: + raise ValueError("Classes list is required for JSON format export") + return self._export_to_json(classes) + + # For OWL formats, world is required + if not world: + raise ValueError("World object is None. Cannot export ontology.") + + # Note: Owlready2 has issues with turtle format export + # We'll handle it specially by converting from rdfxml + use_conversion = (format == "turtle") + + try: + # Get all ontologies in the world + ontologies = list(world.ontologies.values()) + + if not ontologies: + raise RuntimeError("No ontologies found in world") + + # Find the ontology with classes (skip anonymous/empty ontologies) + onto = None + for ont in ontologies: + classes_count = len(list(ont.classes())) + logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes") + if classes_count > 0: + onto = ont + break + + # If no ontology with classes found, use the last non-anonymous one + if onto is None: + for ont in reversed(ontologies): + if ont.base_iri != "http://anonymous/": + onto = ont + break + + # If still no ontology, use the first one + if onto is None: + onto = ontologies[0] + + # Log ontology contents for debugging + logger.info(f"Ontology IRI: {onto.base_iri}") + logger.info(f"Ontology contains {len(list(onto.classes()))} classes") + + # List all classes in the ontology + all_classes = list(onto.classes()) + for cls in all_classes: + logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})") + if hasattr(cls, 'label'): + logger.debug(f" Labels: {cls.label}") + if hasattr(cls, 'comment'): + logger.debug(f" Comments: {cls.comment}") + + if len(all_classes) == 0: + logger.warning("No classes found in ontology! This may indicate a problem with class creation.") + + if output_path: + # Save to file + export_format = "rdfxml" if use_conversion else format + logger.info(f"Exporting ontology to {output_path} in {export_format} format") + onto.save(file=output_path, format=export_format) + + # Read back the file content to return + with open(output_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Convert to turtle if needed + if use_conversion: + content = self._convert_to_turtle(content) + + logger.info(f"Successfully exported ontology to {output_path}") + + # Format the content for better readability + content = self._format_owl_content(content, format) + + return content + else: + # Export to string (save to temporary location and read) + import tempfile + import os + + with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp: + tmp_path = tmp.name + + try: + export_format = "rdfxml" if use_conversion else format + onto.save(file=tmp_path, format=export_format) + + with open(tmp_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Convert to turtle if needed + if use_conversion: + content = self._convert_to_turtle(content) + + # Format the content for better readability + content = self._format_owl_content(content, format) + + return content + + finally: + # Clean up temporary file + if os.path.exists(tmp_path): + os.remove(tmp_path) + + except Exception as e: + error_msg = f"Failed to export ontology: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def _export_to_json(self, classes: List) -> str: + """Export ontology classes to simplified JSON format. + + This format is more compact and easier to parse than OWL XML. + + Args: + classes: List of OntologyClass objects + + Returns: + JSON string representation (compact format) + """ + import json + + result = { + "ontology": { + "namespace": self.base_namespace, + "classes": [] + } + } + + for cls in classes: + class_data = { + "name": cls.name, + "name_chinese": cls.name_chinese, + "description": cls.description, + "entity_type": cls.entity_type, + "domain": cls.domain, + "parent_class": cls.parent_class, + "examples": cls.examples if hasattr(cls, 'examples') else [] + } + result["ontology"]["classes"].append(class_data) + + # 使用紧凑格式:无缩进,使用分隔符减少空格 + return json.dumps(result, ensure_ascii=False, separators=(',', ':')) + + def _convert_to_turtle(self, rdfxml_content: str) -> str: + """Convert RDF/XML content to Turtle format using rdflib. + + Args: + rdfxml_content: RDF/XML format content + + Returns: + Turtle format content + """ + try: + from rdflib import Graph + + # Parse RDF/XML + g = Graph() + g.parse(data=rdfxml_content, format="xml") + + # Serialize to Turtle + turtle_content = g.serialize(format="turtle") + + # Handle bytes vs string + if isinstance(turtle_content, bytes): + turtle_content = turtle_content.decode('utf-8') + + return turtle_content + + except ImportError: + logger.warning( + "rdflib is not installed. Cannot convert to Turtle format. " + "Install with: pip install rdflib" + ) + return rdfxml_content + except Exception as e: + logger.error(f"Failed to convert to Turtle format: {e}") + return rdfxml_content + + def _format_owl_content(self, content: str, format: str) -> str: + """Format OWL content for better readability. + + Args: + content: Raw OWL content string + format: Format type (rdfxml, turtle, ntriples) + + Returns: + Formatted OWL content string + """ + if format == "rdfxml": + # Format XML with proper indentation + try: + import xml.dom.minidom as minidom + dom = minidom.parseString(content) + # Pretty print with 2-space indentation + formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8") + + # Remove extra blank lines + lines = [] + prev_blank = False + for line in formatted.split('\n'): + is_blank = not line.strip() + if not (is_blank and prev_blank): # Skip consecutive blank lines + lines.append(line) + prev_blank = is_blank + + formatted = '\n'.join(lines) + + return formatted + except Exception as e: + logger.warning(f"Failed to format XML content: {e}") + return content + + elif format == "turtle": + # Turtle format is already relatively readable + # Just ensure consistent line endings and not empty + if not content or content.strip() == "": + logger.warning("Turtle content is empty, this may indicate an export issue") + return content.strip() + '\n' if content.strip() else content + + elif format == "ntriples": + # N-Triples format is line-based, ensure proper line endings + return content.strip() + '\n' if content.strip() else content + + return content + + def validate_with_protege_compatibility( + self, + classes: List[OntologyClass] + ) -> Tuple[bool, List[str]]: + """Validate that ontology classes are compatible with Protégé editor. + + Protégé compatibility checks: + - Class names are valid OWL identifiers + - No special characters that Protégé cannot handle + - Namespace is properly formatted + - Labels and comments are properly encoded + + Args: + classes: List of OntologyClass objects to validate + + Returns: + Tuple of (is_compatible, warnings): + - is_compatible: True if compatible with Protégé, False otherwise + - warnings: List of compatibility warning messages + + Examples: + >>> validator = OWLValidator() + >>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")] + >>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes) + >>> is_compatible + True + """ + warnings = [] + + # Check namespace format + if not self.base_namespace.startswith(('http://', 'https://')): + warnings.append( + f"Namespace '{self.base_namespace}' should start with http:// or https:// " + "for Protégé compatibility" + ) + + if not self.base_namespace.endswith(('#', '/')): + warnings.append( + f"Namespace '{self.base_namespace}' should end with # or / " + "for Protégé compatibility" + ) + + # Check each class + for ontology_class in classes: + # Check for special characters that might cause issues + if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']): + warnings.append( + f"Class name '{ontology_class.name}' contains special characters " + "that may cause issues in Protégé" + ) + + # Check description length (Protégé can handle long descriptions but may display poorly) + if ontology_class.description and len(ontology_class.description) > 1000: + warnings.append( + f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) " + "which may display poorly in Protégé" + ) + + # Check for non-ASCII characters (Protégé supports them but encoding issues may occur) + if not ontology_class.name.isascii(): + warnings.append( + f"Class name '{ontology_class.name}' contains non-ASCII characters " + "which may cause encoding issues in some Protégé versions" + ) + + # If no warnings, it's compatible + is_compatible = len(warnings) == 0 + + return is_compatible, warnings diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index c02ca2cb..00b23fb2 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -285,7 +285,7 @@ models: - stream-tool-call logo: dashscope - name: qwen-vl-max - type: llm + type: chat provider: dashscope description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false @@ -298,7 +298,7 @@ models: - video logo: dashscope - name: qwen-vl-plus-0809 - type: llm + type: chat provider: dashscope description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃 is_deprecated: true @@ -311,7 +311,7 @@ models: - video logo: dashscope - name: qwen-vl-plus-2025-01-02 - type: llm + type: chat provider: dashscope description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃 is_deprecated: false @@ -324,7 +324,7 @@ models: - video logo: dashscope - name: qwen-vl-plus-2025-01-25 - type: llm + type: chat provider: dashscope description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false @@ -337,7 +337,7 @@ models: - video logo: dashscope - name: qwen-vl-plus-latest - type: llm + type: chat provider: dashscope description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false @@ -350,7 +350,7 @@ models: - video logo: dashscope - name: qwen-vl-plus - type: llm + type: chat provider: dashscope description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false @@ -616,7 +616,7 @@ models: - audio logo: dashscope - name: qwen3-vl-235b-a22b-instruct - type: llm + type: chat provider: dashscope description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false @@ -631,7 +631,7 @@ models: - video logo: dashscope - name: qwen3-vl-235b-a22b-thinking - type: llm + type: chat provider: dashscope description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false @@ -646,7 +646,7 @@ models: - video logo: dashscope - name: qwen3-vl-30b-a3b-instruct - type: llm + type: chat provider: dashscope description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false @@ -661,7 +661,7 @@ models: - video logo: dashscope - name: qwen3-vl-30b-a3b-thinking - type: llm + type: chat provider: dashscope description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false @@ -676,7 +676,7 @@ models: - video logo: dashscope - name: qwen3-vl-flash - type: llm + type: chat provider: dashscope description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false @@ -691,7 +691,7 @@ models: - video logo: dashscope - name: qwen3-vl-plus-2025-09-23 - type: llm + type: chat provider: dashscope description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false @@ -704,7 +704,7 @@ models: - video logo: dashscope - name: qwen3-vl-plus - type: llm + type: chat provider: dashscope description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 1f696c98..65fbd9cb 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.llm.chat_model import Base from app.core.rag.llm.embedding_model import OpenAIEmbed +import logging +logger = logging.getLogger(__name__) def knowledge_retrieval( query: str, @@ -62,7 +64,15 @@ def knowledge_retrieval( merge_strategy = config.get("merge_strategy", "weight") reranker_id = config.get("reranker_id") reranker_top_k = config.get("reranker_top_k", 1024) - use_graph = config.get("use_graph", "false").lower() == "true" + # use_graph = config.get("use_graph", "false").lower() == "true" + + use_graph_value = config.get("use_graph", False) + if isinstance(use_graph_value, bool): + use_graph = use_graph_value + elif isinstance(use_graph_value, str): + use_graph = use_graph_value.lower() in ("true", "1", "yes") + else: + use_graph = False file_names_filter = [] if user_ids: @@ -159,13 +169,29 @@ def knowledge_retrieval( # Use the specified reranker for re-ranking if reranker_id: - return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) - # use graph + try: + return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) + except Exception as rerank_error: + # If reranker fails, log warning and continue with original results + logger.warning( + "Reranker failed, falling back to original results", + extra={ + "reranker_id": reranker_id, + "query": query, + "doc_count": len(all_results), + "error": str(rerank_error), + }, + ) + if use_graph: - from app.core.rag.common.settings import kg_retriever - doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) - if doc: - all_results.insert(0, doc) + try: + from app.core.rag.common.settings import kg_retriever + doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) + if doc: + all_results.insert(0, doc) + except Exception as graph_error: + print(f"Failed to retrieve from knowledge graph: {str(graph_error)}") + return all_results except Exception as e: diff --git a/api/app/core/workflow/nodes/code/config.py b/api/app/core/workflow/nodes/code/config.py index 8af13f12..a47586a3 100644 --- a/api/app/core/workflow/nodes/code/config.py +++ b/api/app/core/workflow/nodes/code/config.py @@ -44,7 +44,7 @@ class CodeNodeConfig(BaseNodeConfig): description="code content" ) - language: Literal['python3', 'nodejs'] = Field( + language: Literal['python3', 'javascript'] = Field( ..., description="language" ) diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index b2a4da32..daee1e78 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -2,6 +2,7 @@ import base64 import json import logging import re +import urllib.parse from string import Template from textwrap import dedent from typing import Any @@ -14,7 +15,7 @@ from app.core.workflow.nodes.code.config import CodeNodeConfig logger = logging.getLogger(__name__) -SCRIPT_TEMPLATE = Template(dedent(""" +PYTHON_SCRIPT_TEMPLATE = Template(dedent(""" $code import json @@ -32,6 +33,20 @@ result = "<>" + output_json + "<>" print(result) """)) +NODEJS_SCRIPT_TEMPLATE = Template(dedent(""" +$code +// decode and prepare input object +var inputs_obj = JSON.parse(Buffer.from('$inputs_variable', 'base64').toString('utf-8')) + +// execute main function +var output_obj = main(inputs_obj) + +// convert output to json and print +var output_json = JSON.stringify(output_obj) +var result = `<>$${output_json}<>` +console.log(result) +""")) + class CodeNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): @@ -83,18 +98,27 @@ class CodeNode(BaseNode): input_variable_dict = {} for input_variable in self.typed_config.input_variables: input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state) + code = base64.b64decode( self.typed_config.code ).decode("utf-8") + code = urllib.parse.unquote(code, encoding='utf-8') input_variable_dict = base64.b64encode( json.dumps(input_variable_dict).encode("utf-8") ).decode("utf-8") - - final_script = SCRIPT_TEMPLATE.substitute( - code=code, - inputs_variable=input_variable_dict, - ) + if self.typed_config.language == "python3": + final_script = PYTHON_SCRIPT_TEMPLATE.substitute( + code=code, + inputs_variable=input_variable_dict, + ) + elif self.typed_config.language == 'javascript': + final_script = NODEJS_SCRIPT_TEMPLATE.substitute( + code=code, + inputs_variable=input_variable_dict, + ) + else: + raise ValueError(f"Unsupported language: {self.typed_config.language}") async with httpx.AsyncClient() as client: response = await client.post( diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index ec58d96c..079cd4cc 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -23,6 +23,18 @@ class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: ParameterExtractorNodeConfig | None = None + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None @staticmethod def _get_prompt(): @@ -171,6 +183,7 @@ class ParameterExtractorNode(BaseNode): ]) model_resp = await llm.ainvoke(messages) + self.response_metadata = model_resp.response_metadata result = json_repair.repair_json(model_resp.content, return_objects=True) logger.info(f"node: {self.node_id} get params:{result}") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 6df410cb..8076dc9d 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -23,6 +23,18 @@ class QuestionClassifierNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} + self.response_metadata = {} + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + if self.response_metadata: + usage = self.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" @@ -112,6 +124,7 @@ class QuestionClassifierNode(BaseNode): response = await llm.ainvoke(messages) result = response.content.strip() + self.response_metadata = response.response_metadata if result in category_names: category = result diff --git a/api/app/core/workflow/template_loader.py b/api/app/core/workflow/template_loader.py index 4ef49ba5..ef16cf74 100644 --- a/api/app/core/workflow/template_loader.py +++ b/api/app/core/workflow/template_loader.py @@ -4,16 +4,19 @@ 从文件系统加载预定义的工作流模板 """ +import os from pathlib import Path from typing import Optional import yaml +TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates') + class TemplateLoader: """工作流模板加载器""" - - def __init__(self, templates_dir: str = "app/templates/workflows"): + + def __init__(self, templates_dir: str = TEMPLATE_DIR): """初始化模板加载器 Args: @@ -22,7 +25,7 @@ class TemplateLoader: self.templates_dir = Path(templates_dir) if not self.templates_dir.exists(): raise ValueError(f"模板目录不存在: {templates_dir}") - + def list_templates(self) -> list[dict]: """列出所有可用的模板 @@ -30,22 +33,22 @@ class TemplateLoader: 模板列表,每个模板包含 id, name, description 等信息 """ templates = [] - + # 遍历模板目录 for template_dir in self.templates_dir.iterdir(): if not template_dir.is_dir(): continue - + # 检查是否有 template.yml 文件 template_file = template_dir / "template.yml" if not template_file.exists(): continue - + try: # 读取模板配置 with open(template_file, 'r', encoding='utf-8') as f: template_data = yaml.safe_load(f) - + # 提取模板信息 templates.append({ "id": template_dir.name, @@ -59,9 +62,9 @@ class TemplateLoader: except Exception as e: print(f"加载模板 {template_dir.name} 失败: {e}") continue - + return templates - + def load_template(self, template_id: str) -> Optional[dict]: """加载指定的模板 @@ -73,14 +76,14 @@ class TemplateLoader: """ template_dir = self.templates_dir / template_id template_file = template_dir / "template.yml" - + if not template_file.exists(): return None - + try: with open(template_file, 'r', encoding='utf-8') as f: template_data = yaml.safe_load(f) - + # 返回工作流配置部分 return { "name": template_data.get("name", template_id), @@ -94,7 +97,7 @@ class TemplateLoader: except Exception as e: print(f"加载模板 {template_id} 失败: {e}") return None - + def get_template_readme(self, template_id: str) -> Optional[str]: """获取模板的 README 文档 @@ -106,10 +109,10 @@ class TemplateLoader: """ template_dir = self.templates_dir / template_id readme_file = template_dir / "README.md" - + if not readme_file.exists(): return None - + try: with open(readme_file, 'r', encoding='utf-8') as f: return f.read() diff --git a/api/app/templates/workflows/customer_service/template.yml b/api/app/core/workflow/templates/customer_service/template.yml similarity index 100% rename from api/app/templates/workflows/customer_service/template.yml rename to api/app/core/workflow/templates/customer_service/template.yml diff --git a/api/app/templates/workflows/data_processing/template.yml b/api/app/core/workflow/templates/data_processing/template.yml similarity index 100% rename from api/app/templates/workflows/data_processing/template.yml rename to api/app/core/workflow/templates/data_processing/template.yml diff --git a/api/app/templates/workflows/multi_step_qa/template.yml b/api/app/core/workflow/templates/multi_step_qa/template.yml similarity index 100% rename from api/app/templates/workflows/multi_step_qa/template.yml rename to api/app/core/workflow/templates/multi_step_qa/template.yml diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/core/workflow/templates/simple_qa/template.yml similarity index 100% rename from api/app/templates/workflows/simple_qa/template.yml rename to api/app/core/workflow/templates/simple_qa/template.yml diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index a429dd8e..984212de 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -28,6 +28,10 @@ from .tool_model import ( ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus ) from .memory_perceptual_model import MemoryPerceptualModel +from .ontology_scene import OntologyScene +from .ontology_class import OntologyClass +from .ontology_scene import OntologyScene +from .ontology_class import OntologyClass __all__ = [ "Tenants", diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py index 454b1b48..8a451f2d 100644 --- a/api/app/models/memory_config_model.py +++ b/api/app/models/memory_config_model.py @@ -20,6 +20,9 @@ class MemoryConfig(Base): end_user_id = Column(String, nullable=True, comment="组ID") user_id = Column(String, nullable=True, comment="用户ID") apply_id = Column(String, nullable=True, comment="应用ID") + + # 本体场景关联 + scene_id = Column(UUID(as_uuid=True), nullable=True, comment="本体场景ID,关联ontology_scene表") # 模型选择(从workspace继承) llm_id = Column(String, nullable=True, comment="LLM模型配置ID") diff --git a/api/app/models/ontology_class.py b/api/app/models/ontology_class.py new file mode 100644 index 00000000..528d934e --- /dev/null +++ b/api/app/models/ontology_class.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +"""本体类型模型 + +本模块定义本体类型的数据模型。 + +Classes: + OntologyClass: 本体类型表模型 +""" + +import datetime +import uuid +from sqlalchemy import Column, String, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship +from app.db import Base + + +class OntologyClass(Base): + """本体类型表 - 用于存储某个场景提取出来的本体类型信息""" + __tablename__ = "ontology_class" + + # 主键 + class_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="类型ID") + + # 类型信息 + class_name = Column(String(200), nullable=False, comment="类型名称") + class_description = Column(Text, nullable=True, comment="类型描述") + + # 外键:关联到本体场景 + scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间") + + # 关系:类型属于某个场景 + scene = relationship("OntologyScene", back_populates="classes") + + def __repr__(self): + return f"" diff --git a/api/app/models/ontology_scene.py b/api/app/models/ontology_scene.py new file mode 100644 index 00000000..350bfdd6 --- /dev/null +++ b/api/app/models/ontology_scene.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +"""本体场景模型 + +本模块定义本体场景的数据模型。 + +Classes: + OntologyScene: 本体场景表模型 +""" + +import datetime +import uuid +from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship +from app.db import Base + + +class OntologyScene(Base): + """本体场景表 - 用于存储本体场景下不同的类型信息""" + __tablename__ = "ontology_scene" + __table_args__ = ( + UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name'), + ) + + # 主键 + scene_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="场景ID") + + # 场景信息 + scene_name = Column(String(200), nullable=False, comment="场景名称") + scene_description = Column(Text, nullable=True, comment="场景描述") + + # 外键:关联到工作空间 + workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间") + + # 关系:一个场景可以有多个类型 + classes = relationship("OntologyClass", back_populates="scene", cascade="all, delete-orphan") + + def __repr__(self): + return f"" diff --git a/api/app/models/prompt_optimizer_model.py b/api/app/models/prompt_optimizer_model.py index 39845ee7..f96b0a66 100644 --- a/api/app/models/prompt_optimizer_model.py +++ b/api/app/models/prompt_optimizer_model.py @@ -2,7 +2,7 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index +from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index, Boolean from sqlalchemy.dialects.postgresql import UUID from app.db import Base @@ -121,10 +121,33 @@ class PromptOptimizerSessionHistory(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID") # app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID") - session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID") + session_id = Column( + UUID(as_uuid=True), + ForeignKey("prompt_opt_session_list.id"), + nullable=False, + comment="Session ID" + ) user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID") role = Column(String, nullable=False, comment="Message Role") content = Column(Text, nullable=False, comment="Message Content") # prompt = Column(Text, nullable=False, comment="Prompt") created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True) + + +class PromptHistory(Base): + __tablename__ = "prompt_history" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID") + + session_id = Column( + UUID(as_uuid=True), + ForeignKey("prompt_opt_session_list.id"), + nullable=False, + comment="Session ID" + ) + title = Column(String, nullable=False, comment="Title") + prompt = Column(Text, nullable=False, comment="Prompt") + created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True) + is_delete = Column(Boolean, default=False, comment="Delete") diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index c20c79c1..419f7624 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -86,7 +86,8 @@ class MemoryConfigRepository: n.description AS description, n.entity_type AS entity_type, n.name AS name, - COALESCE(n.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(n.fact_summary, '') AS fact_summary, n.end_user_id AS end_user_id, n.apply_id AS apply_id, n.user_id AS user_id, @@ -156,7 +157,7 @@ class MemoryConfigRepository: return memory_config_obj @staticmethod - def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig: + def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: @@ -230,9 +231,12 @@ class MemoryConfigRepository: config_name=params.config_name, config_desc=params.config_desc, workspace_id=params.workspace_id, + scene_id=params.scene_id, llm_id=params.llm_id, embedding_id=params.embedding_id, rerank_id=params.rerank_id, + reflection_model_id=params.reflection_model_id, + emotion_model_id=params.emotion_model_id, ) db.add(db_config) db.flush() # 获取自增ID但不提交事务 @@ -275,6 +279,9 @@ class MemoryConfigRepository: if update.config_desc is not None: db_config.config_desc = update.config_desc has_update = True + if update.scene_id is not None: + db_config.scene_id = update.scene_id + has_update = True if not has_update: raise ValueError("No fields to update") @@ -643,28 +650,32 @@ class MemoryConfigRepository: raise @staticmethod - def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]: - """获取所有配置参数 + def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]: + """获取所有配置参数,包含关联的场景名称 Args: db: 数据库会话 workspace_id: 工作空间ID,用于过滤查询结果 Returns: - List[MemoryConfig]: 配置列表 + List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ + from app.models.ontology_scene import OntologyScene + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: - query = db.query(MemoryConfig) + query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin( + OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id + ) if workspace_id: query = query.filter(MemoryConfig.workspace_id == workspace_id) - configs = query.order_by(desc(MemoryConfig.updated_at)).all() + results = query.order_by(desc(MemoryConfig.updated_at)).all() - db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") - return configs + db_logger.debug(f"配置列表查询成功: 数量={len(results)}") + return results except Exception as e: db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}") diff --git a/api/app/repositories/neo4j/add_edges.py b/api/app/repositories/neo4j/add_edges.py index 162bf411..2b32551c 100644 --- a/api/app/repositories/neo4j/add_edges.py +++ b/api/app/repositories/neo4j/add_edges.py @@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], try: edges: List[dict] = [] for s in summaries: - for chunk_id in getattr(s, "chunk_ids", []) or []: + chunk_ids = getattr(s, "chunk_ids", []) or [] + for chunk_id in chunk_ids: edges.append({ "summary_id": s.id, "chunk_id": chunk_id, @@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], if not edges: return [] - result = await connector.execute_query( MEMORY_SUMMARY_STATEMENT_EDGE_SAVE, edges=edges ) created = [record.get("uuid") for record in result] if result else [] return created - except Exception: + except Exception as e: return None diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index fcf700b5..42c178b3 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids - except Exception: + except Exception as e: + print(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index c93e75b3..651c513f 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity e.name_embedding = CASE WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding ELSE e.name_embedding END, - e.fact_summary = CASE - WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' - AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) - THEN entity.fact_summary ELSE e.fact_summary END, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // e.fact_summary = CASE + // WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' + // AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) + // THEN entity.fact_summary ELSE e.fact_summary END, e.connect_strength = CASE WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength ELSE CASE @@ -321,7 +322,8 @@ RETURN e.id AS id, e.description AS description, e.aliases AS aliases, e.name_embedding AS name_embedding, - COALESCE(e.fact_summary, '') AS fact_summary, + // TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + // COALESCE(e.fact_summary, '') AS fact_summary, e.connect_strength AS connect_strength, collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT c.id) AS chunk_ids, @@ -877,7 +879,8 @@ RETURN CASE WHEN ms:ExtractedEntity THEN { text: ms.name, - created_at: ms.created_at + created_at: ms.created_at, + type: "情景记忆" } END ) AS ExtractedEntity, @@ -887,7 +890,8 @@ RETURN CASE WHEN n:MemorySummary THEN { text: n.content, - created_at: n.created_at + created_at: n.created_at, + type: "长期沉淀" } END ) AS MemorySummary, @@ -895,7 +899,8 @@ RETURN collect( DISTINCT { text: e.statement, - created_at: e.created_at + created_at: e.created_at, + type: "情绪记忆" } ) AS statement; """ @@ -999,3 +1004,58 @@ RETURN DISTINCT x.statement as statement,x.created_at as created_at """ +Graph_Node_query = """ + MATCH (n:MemorySummary) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:Dialogue) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority + LIMIT 1 + + UNION ALL + + MATCH (n:Statement) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:ExtractedEntity) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 2 AS priority + LIMIT $limit + + UNION ALL + + MATCH (n:Chunk) + WHERE n.end_user_id = $end_user_id + RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 3 AS priority + LIMIT $limit + + """ \ No newline at end of file diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 1575315f..5099fd01 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, EntityEntityEdge, ) - +import logging +logger = logging.getLogger(__name__) async def save_entities_and_relationships( entity_nodes: List[ExtractedEntityNode], entity_entity_edges: List[EntityEntityEdge], @@ -41,8 +42,8 @@ async def save_entities_and_relationships( 'statement': edge.statement, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, - 'created_at': edge.created_at.isoformat(), - 'expired_at': edge.expired_at.isoformat(), + 'created_at': edge.created_at.isoformat() if edge.created_at else None, + 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, 'run_id': edge.run_id, 'end_user_id': edge.end_user_id, } @@ -147,14 +148,14 @@ async def save_statement_entity_edges( async def save_dialog_and_statements_to_neo4j( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - entity_edges: List[EntityEntityEdge], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + entity_edges: List[EntityEntityEdge], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + connector: Neo4jConnector ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -171,40 +172,127 @@ async def save_dialog_and_statements_to_neo4j( Returns: bool: True if successful, False otherwise """ - try: - # Save all dialogue nodes in batch - dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector) - if dialogue_uuids: + + # 定义事务函数,将所有写操作放在一个事务中 + async def _save_all_in_transaction(tx): + """在单个事务中执行所有保存操作,避免死锁""" + results = {} + + # 1. Save all dialogue nodes in batch + if dialogue_nodes: + from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE + dialogue_data = [node.model_dump() for node in dialogue_nodes] + result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) + dialogue_uuids = [record["uuid"] async for record in result] + results['dialogues'] = dialogue_uuids print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") - else: - print("Failed to save dialogues to Neo4j") - return False - # Save all chunk nodes in batch - await save_chunk_nodes(chunk_nodes, connector) + # 2. Save all chunk nodes in batch + if chunk_nodes: + from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE + chunk_data = [node.model_dump() for node in chunk_nodes] + result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data) + chunk_uuids = [record["uuid"] async for record in result] + results['chunks'] = chunk_uuids + logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") - # Save all statement nodes in batch + # 3. Save all statement nodes in batch if statement_nodes: - statement_uuids = await add_statement_nodes(statement_nodes, connector) - if statement_uuids: - print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") - else: - print("Failed to save statement nodes to Neo4j") - return False - else: - print("No statement nodes to save") + from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE + statement_data = [node.model_dump() for node in statement_nodes] + result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data) + statement_uuids = [record["uuid"] async for record in result] + results['statements'] = statement_uuids + logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") - # Save entities and relationships - await save_entities_and_relationships(entity_nodes, entity_edges, connector) - print("Successfully saved entities and relationships to Neo4j") + # 4. Save entities + if entity_nodes: + from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE + entity_data = [entity.model_dump() for entity in entity_nodes] + result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data) + entity_uuids = [record["uuid"] async for record in result] + results['entities'] = entity_uuids + logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") - # Save new edges - await save_statement_chunk_edges(statement_chunk_edges, connector) - await save_statement_entity_edges(statement_entity_edges, connector) + # 5. Create entity relationships + if entity_edges: + from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE + relationship_data = [] + for edge in entity_edges: + relationship_data.append({ + 'source_id': edge.source, + 'target_id': edge.target, + 'predicate': edge.relation_type, + 'statement_id': edge.source_statement_id, + 'value': edge.relation_value, + 'statement': edge.statement, + 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, + 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, + 'created_at': edge.created_at.isoformat() if edge.created_at else None, + 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, + 'run_id': edge.run_id, + 'end_user_id': edge.end_user_id, + }) + result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data) + rel_uuids = [record["uuid"] async for record in result] + results['entity_relationships'] = rel_uuids + logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j") + # 6. Save statement-chunk edges + if statement_chunk_edges: + from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE + sc_edge_data = [] + for edge in statement_chunk_edges: + sc_edge_data.append({ + "id": edge.id, + "source": edge.source, + "target": edge.target, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, + "run_id": edge.run_id, + "end_user_id": edge.end_user_id, + }) + result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data) + sc_uuids = [record["uuid"] async for record in result] + results['statement_chunk_edges'] = sc_uuids + logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j") + + # 7. Save statement-entity edges + if statement_entity_edges: + from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE + se_edge_data = [] + for edge in statement_entity_edges: + se_edge_data.append({ + "source": edge.source, + "target": edge.target, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, + "run_id": edge.run_id, + "end_user_id": edge.end_user_id, + "connect_strength": getattr(edge, "connect_strength", "strong"), + }) + result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data) + se_uuids = [record["uuid"] async for record in result] + results['statement_entity_edges'] = se_uuids + logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + + return results + + try: + # 使用显式写事务执行所有操作,避免死锁 + results = await connector.execute_write_transaction(_save_all_in_transaction) + summary = { + key: len(value) + for key, value in results.items() + if isinstance(value, (list, tuple, set)) + } + logger.info("Transaction completed. Summary: %s", summary) + logger.debug("Full transaction results: %r", results) return True except Exception as e: + logger.error(f"Neo4j integration error: {e}", exc_info=True) print(f"Neo4j integration error: {e}") print("Continuing without database storage...") return False + diff --git a/api/app/repositories/ontology_class_repository.py b/api/app/repositories/ontology_class_repository.py new file mode 100644 index 00000000..68f261ff --- /dev/null +++ b/api/app/repositories/ontology_class_repository.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +"""本体类型Repository层 + +本模块提供本体类型的数据访问层实现。 + +Classes: + OntologyClassRepository: 本体类型数据访问类 +""" + +import logging +from typing import List, Optional +from uuid import UUID + +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger +from app.models.ontology_class import OntologyClass +from app.models.ontology_scene import OntologyScene + + +logger = get_db_logger() + + +class OntologyClassRepository: + """本体类型Repository + + 提供本体类型的CRUD操作和权限检查。 + + Attributes: + db: SQLAlchemy数据库会话 + """ + + def __init__(self, db: Session): + """初始化Repository + + Args: + db: SQLAlchemy数据库会话 + """ + self.db = db + + def create(self, class_data: dict, scene_id: UUID) -> OntologyClass: + """创建本体类型 + + Args: + class_data: 类型数据字典,包含class_name和class_description + scene_id: 所属场景ID + + Returns: + OntologyClass: 创建的类型对象 + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologyClassRepository(db) + >>> ontology_class = repo.create( + ... {"class_name": "患者", "class_description": "描述"}, + ... scene_id + ... ) + """ + try: + logger.info( + f"Creating ontology class - " + f"name={class_data.get('class_name')}, " + f"scene_id={scene_id}" + ) + + ontology_class = OntologyClass( + class_name=class_data.get("class_name"), + class_description=class_data.get("class_description"), + scene_id=scene_id + ) + + self.db.add(ontology_class) + self.db.flush() # 获取ID但不提交 + + logger.info( + f"Ontology class created successfully - " + f"class_id={ontology_class.class_id}" + ) + + return ontology_class + + except Exception as e: + logger.error( + f"Failed to create ontology class: {str(e)}", + exc_info=True + ) + raise + + def get_by_id(self, class_id: UUID) -> Optional[OntologyClass]: + """根据ID获取类型 + + Args: + class_id: 类型ID + + Returns: + Optional[OntologyClass]: 类型对象,不存在则返回None + + Examples: + >>> repo = OntologyClassRepository(db) + >>> ontology_class = repo.get_by_id(class_id) + """ + try: + logger.debug(f"Getting ontology class by ID: {class_id}") + + ontology_class = self.db.query(OntologyClass).filter( + OntologyClass.class_id == class_id + ).first() + + if ontology_class: + logger.debug(f"Ontology class found: {class_id}") + else: + logger.debug(f"Ontology class not found: {class_id}") + + return ontology_class + + except Exception as e: + logger.error( + f"Failed to get ontology class by ID: {str(e)}", + exc_info=True + ) + raise + + def get_by_name(self, class_name: str, scene_id: UUID) -> Optional[OntologyClass]: + """根据类型名称和场景ID获取类型(精确匹配) + + Args: + class_name: 类型名称 + scene_id: 场景ID + + Returns: + Optional[OntologyClass]: 类型对象,不存在则返回None + + Examples: + >>> repo = OntologyClassRepository(db) + >>> ontology_class = repo.get_by_name("患者", scene_id) + """ + try: + logger.debug(f"Getting ontology class by name: {class_name}, scene_id: {scene_id}") + + ontology_class = self.db.query(OntologyClass).filter( + OntologyClass.class_name == class_name, + OntologyClass.scene_id == scene_id + ).first() + + if ontology_class: + logger.debug(f"Ontology class found: {class_name}") + else: + logger.debug(f"Ontology class not found: {class_name}") + + return ontology_class + + except Exception as e: + logger.error( + f"Failed to get ontology class by name: {str(e)}", + exc_info=True + ) + raise + + def search_by_name(self, keyword: str, scene_id: UUID) -> List[OntologyClass]: + """根据关键词模糊搜索类型 + + 使用 LIKE 进行模糊匹配,支持中文和英文。 + + Args: + keyword: 搜索关键词 + scene_id: 场景ID + + Returns: + List[OntologyClass]: 匹配的类型列表 + + Examples: + >>> repo = OntologyClassRepository(db) + >>> classes = repo.search_by_name("患者", scene_id) + """ + try: + logger.debug( + f"Searching ontology classes by keyword - " + f"keyword={keyword}, scene_id={scene_id}" + ) + + # 使用 ilike 进行不区分大小写的模糊匹配 + classes = self.db.query(OntologyClass).filter( + OntologyClass.class_name.ilike(f"%{keyword}%"), + OntologyClass.scene_id == scene_id + ).order_by( + OntologyClass.created_at.desc() + ).all() + + logger.info( + f"Found {len(classes)} ontology classes matching keyword '{keyword}' " + f"in scene {scene_id}" + ) + + return classes + + except Exception as e: + logger.error( + f"Failed to search ontology classes by keyword: {str(e)}", + exc_info=True + ) + raise + + def get_by_scene(self, scene_id: UUID) -> List[OntologyClass]: + """获取场景下的所有类型 + + 按创建时间倒序排列。 + + Args: + scene_id: 场景ID + + Returns: + List[OntologyClass]: 类型列表 + + Examples: + >>> repo = OntologyClassRepository(db) + >>> classes = repo.get_by_scene(scene_id) + """ + try: + logger.debug(f"Getting ontology classes by scene: {scene_id}") + + classes = self.db.query(OntologyClass).filter( + OntologyClass.scene_id == scene_id + ).order_by( + OntologyClass.created_at.desc() + ).all() + + logger.info( + f"Found {len(classes)} ontology classes in scene {scene_id}" + ) + + return classes + + except Exception as e: + logger.error( + f"Failed to get ontology classes by scene: {str(e)}", + exc_info=True + ) + raise + + def update(self, class_id: UUID, update_data: dict) -> Optional[OntologyClass]: + """更新类型信息 + + Args: + class_id: 类型ID + update_data: 更新数据字典 + + Returns: + Optional[OntologyClass]: 更新后的类型对象,不存在则返回None + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologyClassRepository(db) + >>> ontology_class = repo.update( + ... class_id, + ... {"class_name": "新名称"} + ... ) + """ + try: + logger.info(f"Updating ontology class: {class_id}") + + ontology_class = self.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Ontology class not found for update: {class_id}") + return None + + # 更新字段 + if "class_name" in update_data and update_data["class_name"] is not None: + ontology_class.class_name = update_data["class_name"] + + if "class_description" in update_data: + ontology_class.class_description = update_data["class_description"] + + self.db.flush() + + logger.info(f"Ontology class updated successfully: {class_id}") + + return ontology_class + + except Exception as e: + logger.error( + f"Failed to update ontology class: {str(e)}", + exc_info=True + ) + raise + + def delete(self, class_id: UUID) -> bool: + """删除类型 + + Args: + class_id: 类型ID + + Returns: + bool: 删除成功返回True,类型不存在返回False + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologyClassRepository(db) + >>> success = repo.delete(class_id) + """ + try: + logger.info(f"Deleting ontology class: {class_id}") + + ontology_class = self.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Ontology class not found for delete: {class_id}") + return False + + self.db.delete(ontology_class) + self.db.flush() + + logger.info(f"Ontology class deleted successfully: {class_id}") + + return True + + except Exception as e: + logger.error( + f"Failed to delete ontology class: {str(e)}", + exc_info=True + ) + raise + + def check_ownership(self, class_id: UUID, workspace_id: UUID) -> bool: + """检查类型是否属于指定工作空间(通过场景关联) + + Args: + class_id: 类型ID + workspace_id: 工作空间ID + + Returns: + bool: 属于返回True,否则返回False + + Examples: + >>> repo = OntologyClassRepository(db) + >>> is_owner = repo.check_ownership(class_id, workspace_id) + """ + try: + logger.debug( + f"Checking class ownership - " + f"class_id={class_id}, workspace_id={workspace_id}" + ) + + count = self.db.query(OntologyClass).join( + OntologyScene, + OntologyClass.scene_id == OntologyScene.scene_id + ).filter( + OntologyClass.class_id == class_id, + OntologyScene.workspace_id == workspace_id + ).count() + + is_owner = count > 0 + + logger.debug( + f"Class ownership check result: {is_owner} - " + f"class_id={class_id}" + ) + + return is_owner + + except Exception as e: + logger.error( + f"Failed to check class ownership: {str(e)}", + exc_info=True + ) + raise + + def get_scene_id_by_class(self, class_id: UUID) -> Optional[UUID]: + """根据类型ID获取所属场景ID + + Args: + class_id: 类型ID + + Returns: + Optional[UUID]: 场景ID,类型不存在则返回None + + Examples: + >>> repo = OntologyClassRepository(db) + >>> scene_id = repo.get_scene_id_by_class(class_id) + """ + try: + logger.debug(f"Getting scene ID by class: {class_id}") + + ontology_class = self.get_by_id(class_id) + if not ontology_class: + logger.debug(f"Class not found: {class_id}") + return None + + logger.debug( + f"Found scene ID: {ontology_class.scene_id} for class: {class_id}" + ) + + return ontology_class.scene_id + + except Exception as e: + logger.error( + f"Failed to get scene ID by class: {str(e)}", + exc_info=True + ) + raise diff --git a/api/app/repositories/ontology_scene_repository.py b/api/app/repositories/ontology_scene_repository.py new file mode 100644 index 00000000..141b5d1c --- /dev/null +++ b/api/app/repositories/ontology_scene_repository.py @@ -0,0 +1,439 @@ +# -*- coding: utf-8 -*- +"""本体场景Repository层 + +本模块提供本体场景的数据访问层实现。 + +Classes: + OntologySceneRepository: 本体场景数据访问类 +""" + +import logging +from typing import List, Optional +from uuid import UUID + +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger +from app.models.ontology_scene import OntologyScene + + +logger = get_db_logger() + + +class OntologySceneRepository: + """本体场景Repository + + 提供本体场景的CRUD操作和权限检查。 + + Attributes: + db: SQLAlchemy数据库会话 + """ + + def __init__(self, db: Session): + """初始化Repository + + Args: + db: SQLAlchemy数据库会话 + """ + self.db = db + + def create(self, scene_data: dict, workspace_id: UUID) -> OntologyScene: + """创建本体场景 + + Args: + scene_data: 场景数据字典,包含scene_name和scene_description + workspace_id: 所属工作空间ID + + Returns: + OntologyScene: 创建的场景对象 + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.create( + ... {"scene_name": "医疗场景", "scene_description": "描述"}, + ... workspace_id + ... ) + """ + try: + logger.info( + f"Creating ontology scene - " + f"name={scene_data.get('scene_name')}, " + f"workspace_id={workspace_id}" + ) + + scene = OntologyScene( + scene_name=scene_data.get("scene_name"), + scene_description=scene_data.get("scene_description"), + workspace_id=workspace_id + ) + + self.db.add(scene) + self.db.flush() # 获取ID但不提交 + + logger.info( + f"Ontology scene created successfully - " + f"scene_id={scene.scene_id}" + ) + + return scene + + except Exception as e: + logger.error( + f"Failed to create ontology scene: {str(e)}", + exc_info=True + ) + raise + + def get_by_id(self, scene_id: UUID) -> Optional[OntologyScene]: + """根据ID获取场景 + + Args: + scene_id: 场景ID + + Returns: + Optional[OntologyScene]: 场景对象,不存在则返回None + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.get_by_id(scene_id) + """ + try: + logger.debug(f"Getting ontology scene by ID: {scene_id}") + + scene = self.db.query(OntologyScene).filter( + OntologyScene.scene_id == scene_id + ).first() + + if scene: + logger.debug(f"Ontology scene found: {scene_id}") + else: + logger.debug(f"Ontology scene not found: {scene_id}") + + return scene + + except Exception as e: + logger.error( + f"Failed to get ontology scene by ID: {str(e)}", + exc_info=True + ) + raise + + def get_by_name(self, scene_name: str, workspace_id: UUID) -> Optional[OntologyScene]: + """根据场景名称和工作空间ID获取场景(精确匹配) + + Args: + scene_name: 场景名称 + workspace_id: 工作空间ID + + Returns: + Optional[OntologyScene]: 场景对象,不存在则返回None + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.get_by_name("医疗场景", workspace_id) + """ + try: + logger.debug( + f"Getting ontology scene by name - " + f"scene_name={scene_name}, workspace_id={workspace_id}" + ) + + scene = self.db.query(OntologyScene).options( + joinedload(OntologyScene.classes) + ).filter( + OntologyScene.scene_name == scene_name, + OntologyScene.workspace_id == workspace_id + ).first() + + if scene: + logger.debug(f"Ontology scene found: {scene_name}") + else: + logger.debug(f"Ontology scene not found: {scene_name}") + + return scene + + except Exception as e: + logger.error( + f"Failed to get ontology scene by name: {str(e)}", + exc_info=True + ) + raise + + def search_by_name(self, keyword: str, workspace_id: UUID) -> List[OntologyScene]: + """根据关键词模糊搜索场景 + + 使用 LIKE 进行模糊匹配,支持中文和英文。 + + Args: + keyword: 搜索关键词 + workspace_id: 工作空间ID + + Returns: + List[OntologyScene]: 匹配的场景列表 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scenes = repo.search_by_name("医疗", workspace_id) + """ + try: + logger.debug( + f"Searching ontology scenes by keyword - " + f"keyword={keyword}, workspace_id={workspace_id}" + ) + + # 使用 ilike 进行不区分大小写的模糊匹配 + scenes = self.db.query(OntologyScene).options( + joinedload(OntologyScene.classes) + ).filter( + OntologyScene.scene_name.ilike(f"%{keyword}%"), + OntologyScene.workspace_id == workspace_id + ).order_by( + OntologyScene.updated_at.desc() + ).all() + + logger.info( + f"Found {len(scenes)} ontology scenes matching keyword '{keyword}' " + f"in workspace {workspace_id}" + ) + + return scenes + + except Exception as e: + logger.error( + f"Failed to search ontology scenes by keyword: {str(e)}", + exc_info=True + ) + raise + + def get_by_workspace(self, workspace_id: UUID, page: Optional[int] = None, page_size: Optional[int] = None) -> tuple: + """获取工作空间下的所有场景(支持分页) + + 使用joinedload预加载classes关系以统计数量。 + + Args: + workspace_id: 工作空间ID + page: 页码(可选,从1开始) + page_size: 每页数量(可选) + + Returns: + tuple: (场景列表, 总数量) + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scenes, total = repo.get_by_workspace(workspace_id) + >>> scenes, total = repo.get_by_workspace(workspace_id, page=1, page_size=10) + """ + try: + logger.debug(f"Getting ontology scenes by workspace: {workspace_id}, page={page}, page_size={page_size}") + + # 构建基础查询 + query = self.db.query(OntologyScene).options( + joinedload(OntologyScene.classes) + ).filter( + OntologyScene.workspace_id == workspace_id + ).order_by( + OntologyScene.updated_at.desc() + ) + + # 获取总数 + total = query.count() + + # 如果提供了分页参数,应用分页 + if page is not None and page_size is not None: + offset = (page - 1) * page_size + query = query.offset(offset).limit(page_size) + logger.debug(f"Applying pagination: offset={offset}, limit={page_size}") + + scenes = query.all() + + logger.info( + f"Found {len(scenes)} ontology scenes (total: {total}) in workspace {workspace_id}" + ) + + return scenes, total + + except Exception as e: + logger.error( + f"Failed to get ontology scenes by workspace: {str(e)}", + exc_info=True + ) + raise + + def update(self, scene_id: UUID, update_data: dict) -> Optional[OntologyScene]: + """更新场景信息 + + Args: + scene_id: 场景ID + update_data: 更新数据字典 + + Returns: + Optional[OntologyScene]: 更新后的场景对象,不存在则返回None + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.update( + ... scene_id, + ... {"scene_name": "新名称"} + ... ) + """ + try: + logger.info(f"Updating ontology scene: {scene_id}") + + scene = self.get_by_id(scene_id) + if not scene: + logger.warning(f"Ontology scene not found for update: {scene_id}") + return None + + # 更新字段 + if "scene_name" in update_data and update_data["scene_name"] is not None: + scene.scene_name = update_data["scene_name"] + + if "scene_description" in update_data: + scene.scene_description = update_data["scene_description"] + + self.db.flush() + + logger.info(f"Ontology scene updated successfully: {scene_id}") + + return scene + + except Exception as e: + logger.error( + f"Failed to update ontology scene: {str(e)}", + exc_info=True + ) + raise + + def delete(self, scene_id: UUID) -> bool: + """删除场景(级联删除类型) + + 依赖数据库级联删除配置(ondelete="CASCADE")。 + + Args: + scene_id: 场景ID + + Returns: + bool: 删除成功返回True,场景不存在返回False + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> success = repo.delete(scene_id) + """ + try: + logger.info(f"Deleting ontology scene: {scene_id}") + + scene = self.get_by_id(scene_id) + if not scene: + logger.warning(f"Ontology scene not found for delete: {scene_id}") + return False + + self.db.delete(scene) + self.db.flush() + + logger.info( + f"Ontology scene deleted successfully (cascade): {scene_id}" + ) + + return True + + except Exception as e: + logger.error( + f"Failed to delete ontology scene: {str(e)}", + exc_info=True + ) + raise + + def check_ownership(self, scene_id: UUID, workspace_id: UUID) -> bool: + """检查场景是否属于指定工作空间 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID + + Returns: + bool: 属于返回True,否则返回False + + Examples: + >>> repo = OntologySceneRepository(db) + >>> is_owner = repo.check_ownership(scene_id, workspace_id) + """ + try: + logger.debug( + f"Checking scene ownership - " + f"scene_id={scene_id}, workspace_id={workspace_id}" + ) + + count = self.db.query(OntologyScene).filter( + OntologyScene.scene_id == scene_id, + OntologyScene.workspace_id == workspace_id + ).count() + + is_owner = count > 0 + + logger.debug( + f"Scene ownership check result: {is_owner} - " + f"scene_id={scene_id}" + ) + + return is_owner + + except Exception as e: + logger.error( + f"Failed to check scene ownership: {str(e)}", + exc_info=True + ) + raise + + def get_simple_list(self, workspace_id: UUID) -> List[dict]: + """获取场景简单列表(仅包含scene_id和scene_name,用于下拉选择) + + 这是一个轻量级查询,不加载关联的classes,响应速度快。 + + Args: + workspace_id: 工作空间ID + + Returns: + List[dict]: 场景简单列表,每项包含scene_id和scene_name + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scenes = repo.get_simple_list(workspace_id) + >>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...] + """ + try: + logger.debug(f"Getting simple scene list for workspace: {workspace_id}") + + # 只查询需要的字段,不加载关联数据 + results = self.db.query( + OntologyScene.scene_id, + OntologyScene.scene_name + ).filter( + OntologyScene.workspace_id == workspace_id + ).order_by( + OntologyScene.updated_at.desc() + ).all() + + scenes = [ + {"scene_id": str(r.scene_id), "scene_name": r.scene_name} + for r in results + ] + + logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}") + + return scenes + + except Exception as e: + logger.error( + f"Failed to get simple scene list: {str(e)}", + exc_info=True + ) + raise diff --git a/api/app/repositories/prompt_optimizer_repository.py b/api/app/repositories/prompt_optimizer_repository.py index ba65257a..e73ab513 100644 --- a/api/app/repositories/prompt_optimizer_repository.py +++ b/api/app/repositories/prompt_optimizer_repository.py @@ -4,7 +4,10 @@ from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger from app.models.prompt_optimizer_model import ( - PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType + PromptOptimizerSession, + PromptOptimizerSessionHistory, + RoleType, + PromptHistory ) db_logger = get_db_logger() @@ -16,6 +19,12 @@ class PromptOptimizerSessionRepository: def __init__(self, db: Session): self.db = db + def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None: + session = self.db.query(PromptOptimizerSession).filter( + PromptOptimizerSession.id == session_id, + ).first() + return session + def create_session( self, tenant_id: uuid.UUID, @@ -38,12 +47,9 @@ class PromptOptimizerSessionRepository: user_id=user_id, ) self.db.add(session) - self.db.commit() - self.db.refresh(session) - db_logger.debug(f"Prompt optimization session created: ID:{session.id}") return session except Exception as e: - db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}") + db_logger.error(f"Error creating prompt optimization session: - {str(e)}") raise def get_session_history( @@ -71,10 +77,10 @@ class PromptOptimizerSessionRepository: PromptOptimizerSession.id == session_id, PromptOptimizerSession.user_id == user_id ).first() - + if not session: return [] - + history = self.db.query(PromptOptimizerSessionHistory).filter( PromptOptimizerSessionHistory.session_id == session.id, PromptOptimizerSessionHistory.user_id == user_id @@ -104,11 +110,11 @@ class PromptOptimizerSessionRepository: PromptOptimizerSession.user_id == user_id, PromptOptimizerSession.tenant_id == tenant_id ).first() - + if not session: db_logger.error(f"Session {session_id} not found for user {user_id}") raise ValueError(f"Session {session_id} not found for user {user_id}") - + message = PromptOptimizerSessionHistory( tenant_id=tenant_id, session_id=session.id, @@ -117,8 +123,199 @@ class PromptOptimizerSessionRepository: content=content, ) self.db.add(message) - self.db.commit() + return message except Exception as e: db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}") raise + + def get_first_user_message(self, session_id: uuid.UUID) -> str | None: + """ + Get the first user message from a session. + + Args: + session_id (uuid.UUID): The session ID. + + Returns: + str | None: The content of the first user message, or None if not found. + """ + try: + message = self.db.query(PromptOptimizerSessionHistory).filter( + PromptOptimizerSessionHistory.session_id == session_id, + PromptOptimizerSessionHistory.role == RoleType.USER.value + ).order_by( + PromptOptimizerSessionHistory.created_at.asc() + ).first() + + return message.content if message else None + except Exception as e: + db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}") + raise + + +class PromptReleaseRepository: + def __init__(self, db: Session): + self.db = db + + def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None: + prompt_obj = self.db.query(PromptHistory).filter( + PromptHistory.session_id == session_id, + PromptHistory.is_delete.is_(False) + ).first() + return prompt_obj + + def create_prompt_release( + self, + tenant_id: uuid.UUID, + title: str, + session_id: uuid.UUID, + prompt: str, + ) -> PromptHistory: + try: + prompt_obj = PromptHistory( + tenant_id=tenant_id, + title=title, + session_id=session_id, + prompt=prompt, + ) + self.db.add(prompt_obj) + return prompt_obj + except Exception as e: + db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}") + raise + + def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None: + """ + Soft delete a prompt release by setting is_delete flag to True. + + Args: + prompt_obj (PromptHistory): The prompt release object to delete. + """ + try: + prompt_obj.is_delete = True + db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}") + except Exception as e: + db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}") + raise + + def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None: + """ + Get a prompt release by its ID. + + Args: + prompt_id (uuid.UUID): The prompt release ID. + + Returns: + PromptHistory | None: The prompt release object or None if not found. + """ + try: + prompt_obj = self.db.query(PromptHistory).filter( + PromptHistory.id == prompt_id + ).first() + return prompt_obj + except Exception as e: + db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}") + raise + + def count_prompts(self, tenant_id: uuid.UUID) -> int: + """ + Count total number of non-deleted prompts for a tenant. + + Args: + tenant_id (uuid.UUID): The tenant ID. + + Returns: + int: Total count of prompts. + """ + try: + count = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False) + ).count() + return count + except Exception as e: + db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}") + raise + + def get_prompts_paginated( + self, + tenant_id: uuid.UUID, + offset: int, + limit: int + ) -> list[PromptHistory]: + """ + Get paginated list of prompt releases for a tenant. + + Args: + tenant_id (uuid.UUID): The tenant ID. + offset (int): Number of records to skip. + limit (int): Maximum number of records to return. + + Returns: + list[PromptHistory]: List of prompt releases. + """ + try: + prompts = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False) + ).order_by( + PromptHistory.created_at.desc() + ).offset(offset).limit(limit).all() + return prompts + except Exception as e: + db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}") + raise + + def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int: + """ + Count total number of non-deleted prompts matching keyword for a tenant. + + Args: + tenant_id (uuid.UUID): The tenant ID. + keyword (str): Search keyword for title. + + Returns: + int: Total count of matching prompts. + """ + try: + count = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False), + PromptHistory.title.ilike(f"%{keyword}%") + ).count() + return count + except Exception as e: + db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}") + raise + + def search_prompts_paginated( + self, + tenant_id: uuid.UUID, + keyword: str, + offset: int, + limit: int + ) -> list[PromptHistory]: + """ + Search prompt releases by keyword in title with pagination. + + Args: + tenant_id (uuid.UUID): The tenant ID. + keyword (str): Search keyword for title. + offset (int): Number of records to skip. + limit (int): Maximum number of records to return. + + Returns: + list[PromptHistory]: List of matching prompt releases. + """ + try: + prompts = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False), + PromptHistory.title.ilike(f"%{keyword}%") + ).order_by( + PromptHistory.created_at.desc() + ).offset(offset).limit(limit).all() + return prompts + except Exception as e: + db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}") + raise diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 09410091..ddaed685 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel): kb_id: str = Field(..., description="知识库ID") top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量") similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值") - strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense") - weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)") + # strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense") + # weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)") vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重") retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid") diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b6f50dd7..1a5017eb 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Optional from pydantic import BaseModel @@ -14,4 +15,15 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): messages: list[dict] end_user_id: str - config_id: Optional[str] = None \ No newline at end of file + config_id: Optional[str] = None + +class AgentMemory_Long_Term(ABC): + """长期记忆配置常量""" + STORAGE_NEO4J = "neo4j" + STORAGE_RAG = "rag" + STRATEGY_AGGREGATE = "aggregate" + STRATEGY_CHUNK = "chunk" + STRATEGY_TIME = "time" + DEFAULT_SCOPE = 6 + + diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 5fda0a1d..c3e7295b 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -229,10 +229,15 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, config_desc: str = Field("配置描述", description="配置描述(字符串)") workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)") + # 本体场景关联(可选) + scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表") + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") + reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") + emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) @@ -243,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 config_id: Union[uuid.UUID, int, str] = None - config_name: str = Field("配置名称", description="配置名称(字符串)") - config_desc: str = Field("配置描述", description="配置描述(字符串)") + config_name: Optional[str] = Field(None, description="配置名称(字符串)") + config_desc: Optional[str] = Field(None, description="配置描述(字符串)") + scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID") class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 diff --git a/api/app/schemas/ontology_schemas.py b/api/app/schemas/ontology_schemas.py new file mode 100644 index 00000000..5a88f84d --- /dev/null +++ b/api/app/schemas/ontology_schemas.py @@ -0,0 +1,461 @@ +"""本体提取API的请求和响应模型 + +本模块定义了本体提取系统的所有API请求和响应的Pydantic模型。 + +Classes: + ExtractionRequest: 本体提取请求模型 + ExtractionResponse: 本体提取响应模型 + ExportRequest: OWL文件导出请求模型 + ExportResponse: OWL文件导出响应模型 + OntologyResultResponse: 本体提取结果响应模型(带毫秒时间戳) + SceneCreateRequest: 场景创建请求模型 + SceneUpdateRequest: 场景更新请求模型 + SceneResponse: 场景响应模型 + SceneListResponse: 场景列表响应模型 + ClassCreateRequest: 类型创建请求模型 + ClassUpdateRequest: 类型更新请求模型 + ClassResponse: 类型响应模型 + ClassListResponse: 类型列表响应模型 +""" + +from typing import List, Optional +import datetime +from uuid import UUID + +from pydantic import BaseModel, Field, field_serializer, ConfigDict + +from app.core.memory.models.ontology_models import OntologyClass + + +class ExtractionRequest(BaseModel): + """本体提取请求模型 + + 用于POST /api/ontology/extract端点的请求体。 + + Attributes: + scenario: 场景描述文本,不能为空 + domain: 可选的领域提示(如Healthcare, Education等) + llm_id: LLM模型ID,必须提供 + scene_id: 场景ID,必须提供,用于将提取的类保存到指定场景 + + Examples: + >>> request = ExtractionRequest( + ... scenario="医院管理患者记录...", + ... domain="Healthcare", + ... llm_id="550e8400-e29b-41d4-a716-446655440000", + ... scene_id="660e8400-e29b-41d4-a716-446655440000" + ... ) + """ + scenario: str = Field(..., description="场景描述文本", min_length=1) + domain: Optional[str] = Field(None, description="可选的领域提示") + llm_id: str = Field(..., description="LLM模型ID") + scene_id: UUID = Field(..., description="场景ID,用于将提取的类保存到指定场景") + + +class ExtractionResponse(BaseModel): + """本体提取响应模型 + + 用于POST /api/ontology/extract端点的响应体。 + + Attributes: + classes: 提取的本体类列表 + domain: 识别的领域 + extracted_count: 提取的类数量 + + Examples: + >>> response = ExtractionResponse( + ... classes=[...], + ... domain="Healthcare", + ... extracted_count=7 + ... ) + """ + classes: List[OntologyClass] = Field(default_factory=list, description="提取的本体类列表") + domain: str = Field(..., description="识别的领域") + extracted_count: int = Field(..., description="提取的类数量") + + +class ExportRequest(BaseModel): + """OWL文件导出请求模型 + + 用于POST /api/ontology/export端点的请求体。 + + Attributes: + classes: 要导出的本体类列表 + format: 导出格式,可选值: rdfxml, turtle, ntriples, json + include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True + + Examples: + >>> request = ExportRequest( + ... classes=[...], + ... format="rdfxml", + ... include_metadata=True + ... ) + """ + classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1) + format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json") + include_metadata: bool = Field(True, description="是否包含完整的OWL元数据") + + +class ExportResponse(BaseModel): + """OWL文件导出响应模型 + + 用于POST /api/ontology/export端点的响应体。 + + Attributes: + owl_content: OWL文件内容 + format: 导出格式 + classes_count: 导出的类数量 + + Examples: + >>> response = ExportResponse( + ... owl_content="...", + ... format="rdfxml", + ... classes_count=7 + ... ) + """ + owl_content: str = Field(..., description="OWL文件内容") + format: str = Field(..., description="导出格式") + classes_count: int = Field(..., description="导出的类数量") + + +class OntologyResultResponse(BaseModel): + """本体提取结果响应模型 + + 用于返回数据库中存储的提取结果,时间戳为毫秒级。 + + Attributes: + id: 结果ID (UUID) + scenario: 场景描述文本 + domain: 领域 + classes_json: 提取的本体类数据(JSON格式) + extracted_count: 提取的类数量 + user_id: 用户ID + created_at: 创建时间(毫秒时间戳) + + Examples: + >>> response = OntologyResultResponse( + ... id=uuid.uuid4(), + ... scenario="医院管理患者记录...", + ... domain="Healthcare", + ... classes_json={"classes": [...]}, + ... extracted_count=7, + ... user_id=123, + ... created_at=datetime.now() + ... ) + """ + id: UUID = Field(..., description="结果ID") + scenario: str = Field(..., description="场景描述文本") + domain: Optional[str] = Field(None, description="领域") + classes_json: dict = Field(..., description="提取的本体类数据(JSON格式)") + extracted_count: int = Field(..., description="提取的类数量") + user_id: Optional[int] = Field(None, description="用户ID") + created_at: datetime.datetime = Field(..., description="创建时间") + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + """将创建时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + class Config: + from_attributes = True + + + +# ==================== 本体场景相关 Schema ==================== + +class SceneCreateRequest(BaseModel): + """场景创建请求模型 + + 用于创建新的本体场景。 + + Attributes: + scene_name: 场景名称,必填,1-200字符 + scene_description: 场景描述,可选 + + Examples: + >>> request = SceneCreateRequest( + ... scene_name="医疗场景", + ... scene_description="用于医疗领域的本体建模" + ... ) + """ + scene_name: str = Field(..., min_length=1, max_length=200, description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + + +class SceneUpdateRequest(BaseModel): + """场景更新请求模型 + + 用于更新已有本体场景信息。 + + Attributes: + scene_name: 场景名称,可选,1-200字符 + scene_description: 场景描述,可选 + + Examples: + >>> request = SceneUpdateRequest( + ... scene_name="更新后的场景名称", + ... scene_description="更新后的描述" + ... ) + """ + scene_name: Optional[str] = Field(None, min_length=1, max_length=200, description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + + +class SceneResponse(BaseModel): + """场景响应模型 + + 用于返回本体场景信息。 + + Attributes: + scene_id: 场景ID + scene_name: 场景名称 + scene_description: 场景描述 + type_num: 类型数量 + workspace_id: 所属工作空间ID + created_at: 创建时间(毫秒时间戳) + updated_at: 更新时间(毫秒时间戳) + classes_count: 类型数量 + + Examples: + >>> response = SceneResponse( + ... scene_id=uuid.uuid4(), + ... scene_name="医疗场景", + ... scene_description="用于医疗领域的本体建模", + ... type_num=0, + ... workspace_id=uuid.uuid4(), + ... created_at=datetime.now(), + ... updated_at=datetime.now(), + ... classes_count=5 + ... ) + """ + scene_id: UUID = Field(..., description="场景ID") + scene_name: str = Field(..., description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + type_num: int = Field(..., description="类型数量") + entity_type: Optional[List[str]] = Field(None, description="实体类型列表(最多3个class_name)") + workspace_id: UUID = Field(..., description="所属工作空间ID") + created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") + updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") + classes_count: int = Field(0, description="类型数量") + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + """将创建时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + """将更新时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + model_config = ConfigDict(from_attributes=True) + + +class PaginationInfo(BaseModel): + """分页信息模型 + + Attributes: + page: 当前页码 + pagesize: 每页数量 + total: 总数量 + hasnext: 是否有下一页 + """ + page: int = Field(..., description="当前页码") + pagesize: int = Field(..., description="每页数量") + total: int = Field(..., description="总数量") + hasnext: bool = Field(..., description="是否有下一页") + + +class SceneListResponse(BaseModel): + """场景列表响应模型(支持分页) + + 用于返回本体场景列表。 + + Attributes: + items: 场景列表 + page: 分页信息(可选,分页时返回) + + Examples: + >>> # 不分页 + >>> response = SceneListResponse( + ... items=[scene1, scene2] + ... ) + >>> # 分页 + >>> response = SceneListResponse( + ... items=[scene1, scene2, ...], + ... page=PaginationInfo(page=1, pagesize=100, total=150, hasnext=True) + ... ) + """ + items: List[SceneResponse] = Field(..., description="场景列表") + page: Optional[PaginationInfo] = Field(None, description="分页信息") + + +# ==================== 本体类型相关 Schema ==================== + +class ClassItem(BaseModel): + """单个类型信息模型 + + Attributes: + class_name: 类型名称,必填,1-200字符 + class_description: 类型描述,可选 + + Examples: + >>> item = ClassItem( + ... class_name="患者", + ... class_description="医院患者信息" + ... ) + """ + class_name: str = Field(..., min_length=1, max_length=200, description="类型名称") + class_description: Optional[str] = Field(None, description="类型描述") + + +class ClassCreateRequest(BaseModel): + """类型创建请求模型(统一使用列表形式) + + 通过列表中元素数量决定创建模式: + - 列表包含 1 个元素:单个创建 + - 列表包含多个元素:批量创建 + + Attributes: + scene_id: 所属场景ID,必填 + classes: 类型列表,必填,至少包含 1 个元素 + + Examples: + # 单个创建(列表中 1 个元素) + >>> request = ClassCreateRequest( + ... scene_id=uuid.uuid4(), + ... classes=[ + ... ClassItem(class_name="患者", class_description="医院患者信息") + ... ] + ... ) + + # 批量创建(列表中多个元素) + >>> request = ClassCreateRequest( + ... scene_id=uuid.uuid4(), + ... classes=[ + ... ClassItem(class_name="患者", class_description="医院患者信息"), + ... ClassItem(class_name="医生", class_description="医院医生信息"), + ... ClassItem(class_name="药品", class_description="医院药品信息") + ... ] + ... ) + """ + scene_id: UUID = Field(..., description="所属场景ID") + classes: List[ClassItem] = Field(..., min_length=1, description="类型列表,至少包含 1 个元素") + + +class ClassUpdateRequest(BaseModel): + """类型更新请求模型 + + 用于更新已有本体类型信息。 + + Attributes: + class_name: 类型名称,可选,1-200字符 + class_description: 类型描述,可选 + + Examples: + >>> request = ClassUpdateRequest( + ... class_name="更新后的类型名称", + ... class_description="更新后的描述" + ... ) + """ + class_name: Optional[str] = Field(None, min_length=1, max_length=200, description="类型名称") + class_description: Optional[str] = Field(None, description="类型描述") + + +class ClassResponse(BaseModel): + """类型响应模型 + + 用于返回本体类型信息。 + + Attributes: + class_id: 类型ID + class_name: 类型名称 + class_description: 类型描述 + scene_id: 所属场景ID + created_at: 创建时间(毫秒时间戳) + updated_at: 更新时间(毫秒时间戳) + + Examples: + >>> response = ClassResponse( + ... class_id=uuid.uuid4(), + ... class_name="患者", + ... class_description="医院患者信息", + ... scene_id=uuid.uuid4(), + ... created_at=datetime.now(), + ... updated_at=datetime.now() + ... ) + """ + class_id: UUID = Field(..., description="类型ID") + class_name: str = Field(..., description="类型名称") + class_description: Optional[str] = Field(None, description="类型描述") + scene_id: UUID = Field(..., description="所属场景ID") + created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") + updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + """将创建时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + """将更新时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + model_config = ConfigDict(from_attributes=True) + + +class ClassBatchCreateResponse(BaseModel): + """批量创建类型响应模型 + + 用于返回批量创建的结果统计和详情。 + + Attributes: + total: 总共尝试创建的数量 + success_count: 成功创建的数量 + failed_count: 失败的数量 + items: 成功创建的类型列表 + errors: 失败的错误信息列表(可选) + + Examples: + >>> response = ClassBatchCreateResponse( + ... total=3, + ... success_count=2, + ... failed_count=1, + ... items=[class1, class2], + ... errors=["创建类型 '药品' 失败: 类型名称已存在"] + ... ) + """ + total: int = Field(..., description="总共尝试创建的数量") + success_count: int = Field(..., description="成功创建的数量") + failed_count: int = Field(0, description="失败的数量") + items: List[ClassResponse] = Field(..., description="成功创建的类型列表") + errors: Optional[List[str]] = Field(None, description="失败的错误信息列表") + + +class ClassListResponse(BaseModel): + """类型列表响应模型 + + 用于返回本体类型列表。 + + Attributes: + total: 总数量 + scene_id: 所属场景ID + scene_name: 场景名称 + scene_description: 场景描述 + items: 类型列表 + + Examples: + >>> response = ClassListResponse( + ... total=3, + ... scene_id=uuid.uuid4(), + ... scene_name="医疗场景", + ... scene_description="用于医疗领域的本体建模", + ... items=[class1, class2, class3] + ... ) + """ + total: int = Field(..., description="总数量") + scene_id: UUID = Field(..., description="所属场景ID") + scene_name: str = Field(..., description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + items: List[ClassResponse] = Field(..., description="类型列表") diff --git a/api/app/schemas/prompt_optimizer_schema.py b/api/app/schemas/prompt_optimizer_schema.py index e1f27be0..08a11317 100644 --- a/api/app/schemas/prompt_optimizer_schema.py +++ b/api/app/schemas/prompt_optimizer_schema.py @@ -22,6 +22,23 @@ class PromptOptMessage(BaseModel): ) +class PromptSaveRequest(BaseModel): + session_id: UUID = Field( + ..., + description="Session ID" + ) + + title: str = Field( + ..., + description="Prompt Title" + ) + + prompt: str = Field( + ..., + description="Optimized prompt content" + ) + + class PromptOptModelSet(BaseModel): id: UUID | None = Field( default=None, diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index c0a66e03..bd9106e5 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -171,7 +171,14 @@ class AppChatService: self.conversation_service.save_conversation_messages( conversation_id=conversation_id, user_message=message, - assistant_message=result["content"] + assistant_message=result["content"], + meta_data={ + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } ) elapsed_time = time.time() - start_time @@ -310,6 +317,7 @@ class AppChatService: # 流式调用 Agent full_content = "" + total_tokens = 0 async for chunk in agent.chat_stream( message=message, history=history, @@ -320,9 +328,12 @@ class AppChatService: config_id=config_id, memory_flag=memory_flag ): - full_content += chunk - # 发送消息块事件 - yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" + if isinstance(chunk, int): + total_tokens = chunk + else: + full_content += chunk + # 发送消息块事件 + yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" elapsed_time = time.time() - start_time @@ -339,7 +350,7 @@ class AppChatService: content=full_content, meta_data={ "model": api_key_obj.model_name, - "usage": {} + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} } ) @@ -416,7 +427,11 @@ class AppChatService: meta_data={ "mode": result.get("mode"), "elapsed_time": result.get("elapsed_time"), - "sub_results": result.get("sub_results") + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) } ) @@ -458,6 +473,7 @@ class AppChatService: yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" full_content = "" + total_tokens = 0 # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) @@ -474,16 +490,26 @@ class AppChatService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - yield event - # 尝试提取内容(用于保存) - if "data:" in event: - try: - data_line = event.split("data: ", 1)[1].strip() - data = json.loads(data_line) - if "content" in data: - full_content += data["content"] - except: - pass + if "sub_usage" in event: + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "total_tokens" in data: + total_tokens += data["total_tokens"] + except: + pass + else: + yield event + # 尝试提取内容(用于保存) + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "content" in data: + full_content += data["content"] + except: + pass elapsed_time = time.time() - start_time @@ -499,7 +525,12 @@ class AppChatService: role="assistant", content=full_content, meta_data={ - "elapsed_time": elapsed_time + "elapsed_time": elapsed_time, + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } ) diff --git a/api/app/services/app_statistics_service.py b/api/app/services/app_statistics_service.py index c164924a..5cfa3229 100644 --- a/api/app/services/app_statistics_service.py +++ b/api/app/services/app_statistics_service.py @@ -187,7 +187,7 @@ class AppStatisticsService: daily_tokens[date_str] = 0 daily_tokens[date_str] += int(tokens) - daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] - total = sum(row["tokens"] for row in daily_data) + daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] + total = sum(row["count"] for row in daily_data) return {"daily": daily_data, "total": total} diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 275d6413..553aefc4 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -1,4 +1,5 @@ """会话服务""" +import os import uuid from datetime import datetime, timedelta from typing import Annotated @@ -298,7 +299,8 @@ class ConversationService: self, conversation_id: uuid.UUID, user_message: str, - assistant_message: str + assistant_message: str, + meta_data: Optional[dict] = None ): """ Save a pair of user and assistant messages to the conversation. @@ -307,6 +309,7 @@ class ConversationService: conversation_id (uuid.UUID): Conversation UUID. user_message (str): User's message content. assistant_message (str): Assistant's response content. + meta_data (Optional[dict]): Optional metadata for the messages. """ self.add_message( conversation_id=conversation_id, @@ -317,7 +320,8 @@ class ConversationService: self.add_message( conversation_id=conversation_id, role="assistant", - content=assistant_message + content=assistant_message, + meta_data=meta_data ) logger.debug( @@ -526,12 +530,12 @@ class ConversationService: takeaways=[], info_score=0, ) - - with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f: + prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') + with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f: system_prompt = f.read() rendered_system_message = Template(system_prompt).render() - with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f: + with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f: user_prompt = f.read() rendered_user_message = Template(user_prompt).render( language=language, diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 524c9ff6..43073555 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str result = task_service.get_task_memory_read_result(task.id) status = result.get("status") logger.info(f"读取任务状态:{status}") + if memory_content: + memory_content = memory_content['answer'] finally: db.close() @@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "content_length": len(str(memory_content)) } ) - return f"检索到以下历史记忆:\n\n{memory_content}" except Exception as e: logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) @@ -442,7 +443,14 @@ class DraftRunService: user_message=message, assistant_message=result["content"], app_id=agent_config.app_id, - user_id=user_id + user_id=user_id, + meta_data={ + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } ) response = { @@ -649,6 +657,7 @@ class DraftRunService: # 9. 流式调用 Agent full_content = "" + total_tokens = 0 async for chunk in agent.chat_stream( message=message, history=history, @@ -659,14 +668,22 @@ class DraftRunService: user_rag_memory_id=user_rag_memory_id, memory_flag=memory_flag ): - full_content += chunk - # 发送消息块事件 - yield self._format_sse_event("message", { - "content": chunk - }) + if isinstance(chunk, int): + total_tokens = chunk + else: + full_content += chunk + # 发送消息块事件 + yield self._format_sse_event("message", { + "content": chunk + }) elapsed_time = time.time() - start_time + if sub_agent: + yield self._format_sse_event("sub_usage", { + "total_tokens": total_tokens + }) + # 10. 保存会话消息 if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): await self._save_conversation_message( @@ -674,7 +691,10 @@ class DraftRunService: user_message=message, assistant_message=full_content, app_id=agent_config.app_id, - user_id=user_id + user_id=user_id, + meta_data={ + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} + } ) # 11. 发送结束事件 @@ -898,6 +918,7 @@ class DraftRunService: conversation_id: str, user_message: str, assistant_message: str, + meta_data: dict, app_id: Optional[uuid.UUID] = None, user_id: Optional[str] = None ) -> None: @@ -909,6 +930,7 @@ class DraftRunService: assistant_message: AI 回复消息 app_id: 应用ID(未使用,保留用于兼容性) user_id: 用户ID(未使用,保留用于兼容性) + meta_data: token消耗 """ try: from app.services.conversation_service import ConversationService @@ -927,7 +949,8 @@ class DraftRunService: conversation_service.add_message( conversation_id=conv_uuid, role="assistant", - content=assistant_message + content=assistant_message, + meta_data=meta_data ) logger.debug( diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index 114e9945..10e4d646 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -4,7 +4,7 @@ import uuid from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated from typing_extensions import TypedDict -from langchain_core.messages import HumanMessage, AIMessage, BaseMessage +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk from langgraph.graph import StateGraph, START, END from langgraph.types import Command from langgraph.checkpoint.memory import MemorySaver @@ -727,9 +727,12 @@ class HandoffsService: # 提取响应 response_content = "" + total_tokens = 0 for msg in result.get("messages", []): if isinstance(msg, AIMessage): response_content = msg.content + response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break return { @@ -737,7 +740,12 @@ class HandoffsService: "active_agent": result.get("active_agent"), "response": response_content, "message_count": len(result.get("messages", [])), - "handoff_count": result.get("handoff_count", 0) + "handoff_count": result.get("handoff_count", 0), + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } async def chat_stream( @@ -830,6 +838,12 @@ class HandoffsService: # 捕获 LLM 结束事件,输出收集到的工具调用 elif kind == "on_chat_model_end": + output_message = event.get("data", {}).get("output", {}) + if isinstance(output_message, AIMessageChunk): + response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", + 0) if response_meta else 0 + yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n" if collected_tool_calls: # 找到参数最完整的 transfer 工具调用 best_tc = None diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 06a94060..6fa8b228 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -53,7 +53,10 @@ def get_workspace_end_users( workspace_id: uuid.UUID, current_user: User ) -> List[EndUser]: - """获取工作空间的所有宿主(优化版本:减少数据库查询次数)""" + """获取工作空间的所有宿主(优化版本:减少数据库查询次数) + + 返回结果按 updated_at 从新到旧排序(NULL 值排在最后) + """ business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") try: @@ -68,9 +71,14 @@ def get_workspace_end_users( app_ids = [app.id for app in apps_orm] # 批量查询所有 end_users(一次查询而非循环查询) + # 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 from app.models.end_user_model import EndUser as EndUserModel + from sqlalchemy import desc, nullslast end_users_orm = db.query(EndUserModel).filter( EndUserModel.app_id.in_(app_ids) + ).order_by( + nullslast(desc(EndUserModel.updated_at)), + desc(EndUserModel.id) ).all() # 转换为 Pydantic 模型(只在需要时转换) diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index b92a5d06..e025c1b3 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -89,7 +89,6 @@ class WorkspaceAppService: for release in app_releases: memory_content = self._extract_memory_content(release.config) - memory_content=resolve_config_id(memory_content, self.db) if memory_content and memory_content in processed_configs: continue @@ -122,16 +121,12 @@ class WorkspaceAppService: def _get_memory_config(self, memory_content: str) -> Dict[str, Any]: """Retrieve memory_config information based on memory_content""" try: - memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content)) - - # memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content) - # memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone() - # if memory_config_result is None: - # return None + memory_content = resolve_config_id(memory_content, self.db) + memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, (memory_content)) if memory_config_result: return { - "config_id": memory_config_result.config_id, + "config_id": memory_content, "enable_self_reflexion": memory_config_result.enable_self_reflexion, "iteration_period": memory_config_result.iteration_period, "reflexion_range": memory_config_result.reflexion_range, @@ -291,7 +286,7 @@ class MemoryReflectionService: # 检查是否需要执行反思 should_execute = False hours_diff = 0 - + if current_reflection_time is None: # 首次执行反思 should_execute = True @@ -303,11 +298,11 @@ class MemoryReflectionService: reflection_time = datetime.fromisoformat(current_reflection_time) else: reflection_time = current_reflection_time - + current_time = datetime.now() time_diff = current_time - reflection_time hours_diff = int(time_diff.total_seconds() / 3600) - + # 检查是否达到反思周期 if hours_diff >= iteration_period: should_execute = True @@ -317,7 +312,7 @@ class MemoryReflectionService: except (ValueError, TypeError) as e: api_logger.warning(f"解析反思时间失败: {e},将执行反思") should_execute = True - + if should_execute: api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时") # 3. 执行反思引擎 @@ -350,7 +345,7 @@ class MemoryReflectionService: "next_reflection_in_hours": iteration_period - hours_diff } - + except Exception as e: config_id = config_data.get("config_id", "unknown") api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}") @@ -361,7 +356,7 @@ class MemoryReflectionService: "end_user_id": end_user_id, "config_data": config_data } - + def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig: """Create reflective configuration objects from configuration data""" @@ -369,12 +364,12 @@ class MemoryReflectionService: if reflexion_range_value is None or reflexion_range_value == "": reflexion_range_value = "partial" reflexion_range = ReflectionRange(reflexion_range_value) - + baseline_value = config_data.get("baseline") if baseline_value is None or baseline_value == "": baseline_value = "TIME" baseline = ReflectionBaseline(baseline_value) - + # iteration_period = iteration_period = config_data.get("iteration_period", 24) if isinstance(iteration_period, str): @@ -382,7 +377,6 @@ class MemoryReflectionService: iteration_period = int(iteration_period) except (ValueError, TypeError): iteration_period = 24 # 默认24小时 - return ReflectionConfig( enabled=config_data.get("enable_self_reflexion", False), iteration_period=str(iteration_period), # ReflectionConfig期望字符串 diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index eec1007b..d3d267be 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) if not params.rerank_id: params.rerank_id = configs.get('rerank') + # reflection_model_id 和 emotion_model_id 默认与 llm_id 一致 + if not params.reflection_model_id: + params.reflection_model_id = params.llm_id + if not params.emotion_model_id: + params.emotion_model_id = params.llm_id + config = MemoryConfigRepository.create(self.db, params) self.db.commit() return {"affected": 1, "config_id": config.config_id} @@ -177,11 +183,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Read All --- def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 - configs = MemoryConfigRepository.get_all(self.db, workspace_id) + results = MemoryConfigRepository.get_all(self.db, workspace_id) # 将 ORM 对象转换为字典列表 data_list = [] - for config in configs: + for config, scene_name in results: # 安全地转换 user_id 为 int config_id_old = None if config.config_id_old: @@ -203,6 +209,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "end_user_id": config.end_user_id, "config_id_old": config_id_old, "apply_id": config.apply_id, + "scene_id": str(config.scene_id) if config.scene_id else None, + "scene_name": scene_name, # 新增:场景名称 "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, @@ -628,10 +636,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: if m < 1: latest_relative = "刚刚" elif m < 60: - latest_relative = f"{m}分钟前" + latest_relative = "一会前" else: - h = int(m // 60) - latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前" + latest_relative = "较早前" except Exception: pass diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index d9062eaf..b28bafbf 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -280,14 +280,22 @@ class MultiAgentOrchestrator: # 4. 提取子 Agent 的 conversation_id(用于多轮对话) sub_conversation_id = None + total_tokens = 0 + if isinstance(results, dict): sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") + # 提取 token 信息 + usage = results.get("usage", {}) or results.get("result", {}).get("usage", {}) + total_tokens += usage.get("total_tokens", 0) elif isinstance(results, list) and results: for item in results: if "result" in item: sub_conversation_id = item["result"].get("conversation_id") if sub_conversation_id: break + # 累加每个子 Agent 的 token + usage = item.get("usage", {}) or item.get("result", {}).get("usage", {}) + total_tokens += usage.get("total_tokens", 0) logger.info( "多 Agent 任务完成", @@ -301,9 +309,15 @@ class MultiAgentOrchestrator: return { "message": final_result, "conversation_id": sub_conversation_id, + "mode": OrchestrationMode.SUPERVISOR, "elapsed_time": elapsed_time, "strategy": routing_decision.get("collaboration_strategy", "single"), - "sub_results": results + "sub_results": results, + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } except Exception as e: @@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator: return { "message": result.get("response", ""), "conversation_id": result.get("conversation_id"), + "mode": OrchestrationMode.COLLABORATION, "elapsed_time": elapsed_time, "strategy": "collaboration", "active_agent": result.get("active_agent"), - "sub_results": result + "sub_results": result, + "usage": result.get("usage") } except Exception as e: diff --git a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index da984d16..c52814ed 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -1,5 +1,6 @@ """多 Agent 配置管理服务""" import uuid +import json from typing import Optional, List, Tuple, Any, Annotated from fastapi import Depends @@ -427,6 +428,23 @@ class MultiAgentService: memory=getattr(request, 'memory', True) # 记忆功能参数 ) + await self._save_conversation_message( + conversation_id=request.conversation_id, + user_message=request.message, + assistant_message=result.get("message", ""), + app_id=app_id, + user_id=request.user_id, + meta_data={ + "mode": result.get("mode"), + "elapsed_time": result.get("elapsed_time"), + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } + ) + return result async def run_stream( @@ -451,11 +469,14 @@ class MultiAgentService: raise ResourceNotFoundException("多 Agent 配置", str(app_id)) if not config.is_active: - raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED) + raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND) # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) + full_content = "" + total_tokens = 0 + # 3. 流式执行任务 async for event in orchestrator.execute_stream( message=request.message, @@ -468,7 +489,88 @@ class MultiAgentService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - yield event + if "sub_usage" in event: + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "total_tokens" in data: + total_tokens += data["total_tokens"] + except: + pass + else: + yield event + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "content" in data: + full_content += data["content"] + except: + pass + + await self._save_conversation_message( + conversation_id=request.conversation_id, + user_message=request.message, + assistant_message=full_content, + app_id=app_id, + user_id=request.user_id, + meta_data={ + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } + } + ) + + async def _save_conversation_message( + self, + conversation_id: uuid.UUID, + user_message: str, + assistant_message: str, + meta_data: dict, + app_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None + ) -> None: + """保存会话消息 + + Args: + conversation_id: 会话ID + user_message: 用户消息 + assistant_message: AI 回复消息 + meta_data: 元数据(包括 token 消耗) + app_id: 应用ID + user_id: 用户ID + """ + try: + from app.services.conversation_service import ConversationService + + conversation_service = ConversationService(self.db) + + conversation_service.add_message( + conversation_id=conversation_id, + role="user", + content=user_message + ) + conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=assistant_message, + meta_data=meta_data + ) + + logger.debug( + "保存多 Agent 会话消息", + extra={ + "conversation_id": conversation_id, + "user_message_length": len(user_message), + "assistant_message_length": len(assistant_message) + } + ) + + except Exception as e: + logger.warning("保存会话消息失败", extra={"error": str(e)}) # def add_sub_agent( # self, diff --git a/api/app/services/ontology_service.py b/api/app/services/ontology_service.py new file mode 100644 index 00000000..c832b0cc --- /dev/null +++ b/api/app/services/ontology_service.py @@ -0,0 +1,1162 @@ +"""本体提取服务层 + +本模块提供本体提取的业务逻辑封装,协调OntologyExtractor和OWLValidator。 +包括本体提取、OWL文件导出等功能。 + +Classes: + OntologyService: 本体提取服务类,封装业务逻辑 +""" + +import logging +import time +from typing import Any, Dict, List, Optional + +from sqlalchemy.orm import Session + +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.memory.models.ontology_models import ( + OntologyClass, + OntologyExtractionResponse, +) +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.ontology_extraction import ( + OntologyExtractor, +) +from app.core.memory.utils.validation.owl_validator import OWLValidator + + +logger = logging.getLogger(__name__) + + +class OntologyService: + """本体提取服务层 + + 封装本体提取的业务逻辑,协调各个组件: + - OntologyExtractor: 执行LLM驱动的本体提取 + - OWLValidator: OWL语义验证 + + Attributes: + extractor: 本体提取器实例 + owl_validator: OWL验证器实例 + db: 数据库会话 + """ + + # 默认配置参数 + DEFAULT_MAX_CLASSES = 15 + DEFAULT_MIN_CLASSES = 5 + DEFAULT_MAX_DESCRIPTION_LENGTH = 500 + DEFAULT_LLM_TEMPERATURE = 0.3 + DEFAULT_LLM_MAX_TOKENS = 2000 + DEFAULT_LLM_TIMEOUT = 30.0 + DEFAULT_ENABLE_OWL_VALIDATION = True + + def __init__( + self, + llm_client: OpenAIClient, + db: Session + ): + """初始化本体提取服务 + + Args: + llm_client: OpenAI客户端实例 + db: SQLAlchemy数据库会话 + """ + self.extractor = OntologyExtractor(llm_client) + self.owl_validator = OWLValidator() + self.db = db + + # 初始化Repository + from app.repositories.ontology_scene_repository import OntologySceneRepository + from app.repositories.ontology_class_repository import OntologyClassRepository + + self.scene_repo = OntologySceneRepository(db) + self.class_repo = OntologyClassRepository(db) + + logger.info("OntologyService initialized") + + async def extract_ontology( + self, + scenario: str, + domain: Optional[str] = None, + scene_id: Optional[Any] = None, + workspace_id: Optional[Any] = None + ) -> OntologyExtractionResponse: + """执行本体提取 + + 使用默认配置参数调用OntologyExtractor执行提取。 + 提取结果仅返回给前端,不会自动保存到数据库。 + 前端需要调用 /class 接口来保存选中的类型。 + + Args: + scenario: 场景描述文本 + domain: 可选的领域提示 + scene_id: 可选的场景ID,用于权限验证(不再用于自动保存) + workspace_id: 可选的工作空间ID,用于权限验证 + + Returns: + OntologyExtractionResponse: 提取结果 + + Raises: + ValueError: 场景描述为空、场景不存在或无权限 + RuntimeError: 提取过程失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> response = await service.extract_ontology( + ... scenario="医院管理患者记录...", + ... domain="Healthcare", + ... scene_id=scene_uuid, + ... workspace_id=workspace_uuid + ... ) + >>> len(response.classes) + 7 + """ + # 开始计时 + start_time = time.time() + + # 验证输入 + if not scenario or not scenario.strip(): + logger.error("Scenario description is empty") + raise ValueError("Scenario description cannot be empty") + + # 如果提供了scene_id,验证场景是否存在且有权限 + if scene_id and workspace_id: + logger.info(f"Validating scene access - scene_id={scene_id}, workspace_id={workspace_id}") + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限在该场景下创建类型") + + logger.info( + f"Starting ontology extraction service - " + f"scenario_length={len(scenario)}, " + f"domain={domain}, " + f"scene_id={scene_id}" + ) + + try: + # 调用提取器执行提取(使用默认配置) + logger.info("Calling OntologyExtractor with default config") + extraction_start_time = time.time() + + response = await self.extractor.extract_ontology_classes( + scenario=scenario, + domain=domain, + max_classes=self.DEFAULT_MAX_CLASSES, + min_classes=self.DEFAULT_MIN_CLASSES, + enable_owl_validation=self.DEFAULT_ENABLE_OWL_VALIDATION, + llm_temperature=self.DEFAULT_LLM_TEMPERATURE, + llm_max_tokens=self.DEFAULT_LLM_MAX_TOKENS, + max_description_length=self.DEFAULT_MAX_DESCRIPTION_LENGTH, + timeout=self.DEFAULT_LLM_TIMEOUT, + ) + + extraction_duration = time.time() - extraction_start_time + + # 检查是否成功提取到类 + if not response.classes: + logger.error("Ontology extraction failed: No classes extracted (structured output may have failed)") + raise RuntimeError("本体提取失败:结构化输出失败,未能提取到任何本体类") + + # 注释:提取结果仅返回给前端,不保存到数据库 + # 前端将从返回结果中选择需要的类型,然后调用 /class 接口创建 + logger.info( + f"Extraction completed. Classes will be saved to ontology_class " + f"via /class endpoint based on user selection" + ) + + total_duration = time.time() - start_time + + # 记录提取统计 + logger.info( + f"Ontology extraction service completed - " + f"extracted_classes={len(response.classes)}, " + f"domain={response.domain}, " + f"extraction_duration={extraction_duration:.2f}s, " + f"total_duration={total_duration:.2f}s" + ) + + return response + + except ValueError: + # 重新抛出验证错误 + total_duration = time.time() - start_time + logger.error( + f"Validation error after {total_duration:.2f}s", + exc_info=True + ) + raise + except Exception as e: + total_duration = time.time() - start_time + error_msg = f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + async def export_owl_file( + self, + classes: List[OntologyClass], + output_path: str, + format: str = "rdfxml", + ) -> str: + """导出OWL文件 + + 将提取的本体类导出为OWL文件,支持多种格式。 + + Args: + classes: 本体类列表 + output_path: 输出文件路径 + format: 导出格式,可选值: "rdfxml", "turtle", "ntriples" (默认: "rdfxml") + + Returns: + str: 导出的OWL文件内容 + + Raises: + ValueError: 类列表为空或格式不支持 + RuntimeError: 导出失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> owl_content = await service.export_owl_file( + ... classes=response.classes, + ... output_path="ontology.owl", + ... format="rdfxml" + ... ) + """ + # 验证输入 + if not classes: + logger.error("Classes list is empty") + raise ValueError("Classes list cannot be empty") + + valid_formats = ["rdfxml", "turtle", "ntriples"] + if format not in valid_formats: + error_msg = f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}" + logger.error(error_msg) + raise ValueError(error_msg) + + logger.info( + f"Starting OWL export - " + f"classes_count={len(classes)}, " + f"output_path={output_path}, " + f"format={format}" + ) + + try: + # 步骤1: 验证本体类 + logger.debug("Validating ontology classes") + is_valid, errors, world = self.owl_validator.validate_ontology_classes( + classes=classes, + ) + + if not is_valid: + logger.warning( + f"OWL validation found {len(errors)} issues during export: {errors}" + ) + # 继续导出,但记录警告 + + if not world: + error_msg = "Failed to create OWL world for export" + logger.error(error_msg) + raise RuntimeError(error_msg) + + # 步骤2: 导出OWL文件 + logger.info(f"Exporting to {format} format") + owl_content = self.owl_validator.export_to_owl( + world=world, + output_path=output_path, + format=format + ) + + logger.info( + f"OWL export completed - " + f"output_path={output_path}, " + f"content_length={len(owl_content)}" + ) + + return owl_content + + except Exception as e: + error_msg = f"OWL export failed: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + + # ==================== 本体场景管理方法 ==================== + + def create_scene( + self, + scene_name: str, + scene_description: Optional[str], + workspace_id: Any + ): + """创建本体场景 + + Args: + scene_name: 场景名称 + scene_description: 场景描述 + workspace_id: 所属工作空间ID + + Returns: + OntologyScene: 创建的场景对象 + + Raises: + ValueError: 场景名称为空 + RuntimeError: 创建失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.create_scene( + ... "医疗场景", + ... "用于医疗领域的本体建模", + ... workspace_id + ... ) + """ + # 验证输入 + if not scene_name or not scene_name.strip(): + logger.error("Scene name is empty") + raise ValueError("场景名称不能为空") + + logger.info( + f"Creating scene - " + f"name={scene_name}, workspace_id={workspace_id}" + ) + + try: + scene_data = { + "scene_name": scene_name.strip(), + "scene_description": scene_description + } + + scene = self.scene_repo.create(scene_data, workspace_id) + self.db.commit() + + logger.info(f"Scene created successfully: {scene.scene_id}") + + return scene + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to create scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def update_scene( + self, + scene_id: Any, + scene_name: Optional[str], + scene_description: Optional[str], + workspace_id: Any + ): + """更新本体场景 + + Args: + scene_id: 场景ID + scene_name: 场景名称(可选) + scene_description: 场景描述(可选) + workspace_id: 工作空间ID(用于权限验证) + + Returns: + OntologyScene: 更新后的场景对象 + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 更新失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.update_scene( + ... scene_id, + ... "新名称", + ... "新描述", + ... workspace_id + ... ) + """ + logger.info(f"Updating scene: {scene_id}") + + try: + # 检查场景是否存在 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + # 检查权限 + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该场景") + + # 准备更新数据 + update_data = {} + if scene_name is not None: + if not scene_name.strip(): + raise ValueError("场景名称不能为空") + update_data["scene_name"] = scene_name.strip() + + if scene_description is not None: + update_data["scene_description"] = scene_description + + # 如果没有更新数据,直接返回 + if not update_data: + logger.info("No update data provided, returning existing scene") + return scene + + # 执行更新 + updated_scene = self.scene_repo.update(scene_id, update_data) + self.db.commit() + + logger.info(f"Scene updated successfully: {scene_id}") + + return updated_scene + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to update scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def delete_scene( + self, + scene_id: Any, + workspace_id: Any + ) -> bool: + """删除本体场景 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + bool: 删除成功返回True + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 删除失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> success = service.delete_scene(scene_id, workspace_id) + """ + logger.info(f"Deleting scene: {scene_id}") + + try: + # 检查场景是否存在 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + # 检查权限 + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该场景") + + # 执行删除 + success = self.scene_repo.delete(scene_id) + self.db.commit() + + logger.info(f"Scene deleted successfully: {scene_id}") + + return success + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to delete scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_scene_by_id( + self, + scene_id: Any, + workspace_id: Any + ): + """获取单个场景 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Optional[OntologyScene]: 场景对象 + + Raises: + ValueError: 场景不存在或无权限 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.get_scene_by_id(scene_id, workspace_id) + """ + logger.debug(f"Getting scene by ID: {scene_id}") + + try: + # 获取场景 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + # 检查权限 + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景") + + return scene + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to get scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_scene_by_name( + self, + scene_name: str, + workspace_id: Any + ): + """根据场景名称获取场景(精确匹配) + + Args: + scene_name: 场景名称 + workspace_id: 工作空间ID + + Returns: + Optional[OntologyScene]: 场景对象 + + Raises: + ValueError: 场景不存在 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.get_scene_by_name("医疗场景", workspace_id) + """ + logger.debug(f"Getting scene by name: {scene_name}, workspace_id: {workspace_id}") + + try: + # 获取场景 + scene = self.scene_repo.get_by_name(scene_name, workspace_id) + if not scene: + logger.warning(f"Scene not found: {scene_name} in workspace {workspace_id}") + raise ValueError("场景不存在") + + return scene + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to get scene by name: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def search_scenes_by_name( + self, + keyword: str, + workspace_id: Any + ) -> List: + """根据关键词模糊搜索场景 + + Args: + keyword: 搜索关键词 + workspace_id: 工作空间ID + + Returns: + List[OntologyScene]: 匹配的场景列表 + + Raises: + RuntimeError: 搜索失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scenes = service.search_scenes_by_name("医疗", workspace_id) + """ + logger.debug(f"Searching scenes by keyword: {keyword}, workspace_id: {workspace_id}") + + try: + scenes = self.scene_repo.search_by_name(keyword, workspace_id) + + logger.info( + f"Found {len(scenes)} scenes matching keyword '{keyword}' " + f"in workspace {workspace_id}" + ) + + return scenes + + except Exception as e: + error_msg = f"Failed to search scenes by keyword: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def list_scenes( + self, + workspace_id: Any, + page: Optional[int] = None, + page_size: Optional[int] = None + ) -> tuple: + """获取工作空间下的所有场景(支持分页) + + Args: + workspace_id: 工作空间ID + page: 页码(可选,从1开始) + page_size: 每页数量(可选) + + Returns: + tuple: (场景列表, 总数量) + + Raises: + RuntimeError: 查询失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scenes, total = service.list_scenes(workspace_id) + >>> scenes, total = service.list_scenes(workspace_id, page=1, page_size=10) + """ + logger.debug(f"Listing scenes for workspace: {workspace_id}, page={page}, page_size={page_size}") + + try: + scenes, total = self.scene_repo.get_by_workspace(workspace_id, page, page_size) + + logger.info(f"Found {len(scenes)} scenes (total: {total}) in workspace {workspace_id}") + + return scenes, total + + except Exception as e: + error_msg = f"Failed to list scenes: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + # ==================== 本体类型管理方法 ==================== + + def create_class( + self, + scene_id: Any, + class_name: str, + class_description: Optional[str], + workspace_id: Any + ): + """创建本体类型 + + Args: + scene_id: 所属场景ID + class_name: 类型名称 + class_description: 类型描述 + workspace_id: 工作空间ID(用于权限验证) + + Returns: + OntologyClass: 创建的类型对象 + + Raises: + ValueError: 类型名称为空、场景不存在或无权限 + RuntimeError: 创建失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.create_class( + ... scene_id, + ... "患者", + ... "医院患者信息", + ... workspace_id + ... ) + """ + # 验证输入 + if not class_name or not class_name.strip(): + logger.error("Class name is empty") + raise ValueError("类型名称不能为空") + + logger.info( + f"Creating class - " + f"name={class_name}, scene_id={scene_id}" + ) + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("所属场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限在该场景下创建类型") + + # 创建类型 + class_data = { + "class_name": class_name.strip(), + "class_description": class_description + } + + ontology_class = self.class_repo.create(class_data, scene_id) + self.db.commit() + + logger.info(f"Class created successfully: {ontology_class.class_id}") + + return ontology_class + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to create class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def create_classes_batch( + self, + scene_id: Any, + classes: List[Dict[str, Optional[str]]], + workspace_id: Any + ): + """批量创建本体类型 + + Args: + scene_id: 所属场景ID + classes: 类型列表,每个元素包含 class_name 和 class_description + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Tuple[List, List[str]]: (成功创建的类型列表, 错误信息列表) + + Raises: + ValueError: 场景不存在或无权限 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> classes_data = [ + ... {"class_name": "患者", "class_description": "医院患者信息"}, + ... {"class_name": "医生", "class_description": "医院医生信息"} + ... ] + >>> created_classes, errors = service.create_classes_batch( + ... scene_id, + ... classes_data, + ... workspace_id + ... ) + """ + logger.info( + f"Batch creating classes - " + f"count={len(classes)}, scene_id={scene_id}" + ) + + # 检查场景是否存在且属于当前工作空间(只检查一次) + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("所属场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限在该场景下创建类型") + + created_classes = [] + errors = [] + + for idx, class_data in enumerate(classes): + class_name = class_data.get("class_name", "").strip() + class_description = class_data.get("class_description") + + if not class_name: + error_msg = f"第 {idx + 1} 个类型名称为空,已跳过" + logger.warning(error_msg) + errors.append(error_msg) + continue + + try: + # 创建类型(不需要再次检查权限) + create_data = { + "class_name": class_name, + "class_description": class_description + } + + ontology_class = self.class_repo.create(create_data, scene_id) + created_classes.append(ontology_class) + logger.info(f"Class created successfully: {class_name}") + + except Exception as e: + error_msg = f"创建类型 '{class_name}' 失败: {str(e)}" + logger.error(error_msg) + errors.append(error_msg) + + # 统一提交所有成功的创建 + try: + self.db.commit() + logger.info( + f"Batch creation completed - " + f"success={len(created_classes)}, failed={len(errors)}" + ) + except Exception as e: + self.db.rollback() + error_msg = f"批量创建提交失败: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + return created_classes, errors + + def update_class( + self, + class_id: Any, + class_name: Optional[str], + class_description: Optional[str], + workspace_id: Any + ): + """更新本体类型 + + Args: + class_id: 类型ID + class_name: 类型名称(可选) + class_description: 类型描述(可选) + workspace_id: 工作空间ID(用于权限验证) + + Returns: + OntologyClass: 更新后的类型对象 + + Raises: + ValueError: 类型不存在或无权限 + RuntimeError: 更新失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.update_class( + ... class_id, + ... "新名称", + ... "新描述", + ... workspace_id + ... ) + """ + logger.info(f"Updating class: {class_id}") + + try: + # 检查类型是否存在 + ontology_class = self.class_repo.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Class not found: {class_id}") + raise ValueError("类型不存在") + + # 检查权限(通过场景关联) + if not self.class_repo.check_ownership(class_id, workspace_id): + logger.warning( + f"Permission denied - class_id={class_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该类型") + + # 准备更新数据 + update_data = {} + if class_name is not None: + if not class_name.strip(): + raise ValueError("类型名称不能为空") + update_data["class_name"] = class_name.strip() + + if class_description is not None: + update_data["class_description"] = class_description + + # 如果没有更新数据,直接返回 + if not update_data: + logger.info("No update data provided, returning existing class") + return ontology_class + + # 执行更新 + updated_class = self.class_repo.update(class_id, update_data) + self.db.commit() + + logger.info(f"Class updated successfully: {class_id}") + + return updated_class + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to update class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def delete_class( + self, + class_id: Any, + workspace_id: Any + ) -> bool: + """删除本体类型 + + Args: + class_id: 类型ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + bool: 删除成功返回True + + Raises: + ValueError: 类型不存在或无权限 + RuntimeError: 删除失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> success = service.delete_class(class_id, workspace_id) + """ + logger.info(f"Deleting class: {class_id}") + + try: + # 检查类型是否存在 + ontology_class = self.class_repo.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Class not found: {class_id}") + raise ValueError("类型不存在") + + # 检查权限(通过场景关联) + if not self.class_repo.check_ownership(class_id, workspace_id): + logger.warning( + f"Permission denied - class_id={class_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该类型") + + # 执行删除 + success = self.class_repo.delete(class_id) + self.db.commit() + + logger.info(f"Class deleted successfully: {class_id}") + + return success + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to delete class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_class_by_id( + self, + class_id: Any, + workspace_id: Any + ): + """获取单个类型 + + Args: + class_id: 类型ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Optional[OntologyClass]: 类型对象 + + Raises: + ValueError: 类型不存在或无权限 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.get_class_by_id(class_id, workspace_id) + """ + logger.debug(f"Getting class by ID: {class_id}") + + try: + # 获取类型 + ontology_class = self.class_repo.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Class not found: {class_id}") + raise ValueError("类型不存在") + + # 检查权限(通过场景关联) + if not self.class_repo.check_ownership(class_id, workspace_id): + logger.warning( + f"Permission denied - class_id={class_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该类型") + + return ontology_class + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to get class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_class_by_name( + self, + class_name: str, + scene_id: Any, + workspace_id: Any + ): + """根据类型名称获取类型(精确匹配) + + Args: + class_name: 类型名称 + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Optional[OntologyClass]: 类型对象 + + Raises: + ValueError: 类型不存在或无权限 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.get_class_by_name("患者", scene_id, workspace_id) + """ + logger.debug(f"Getting class by name: {class_name}, scene_id: {scene_id}") + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景") + + # 获取类型 + ontology_class = self.class_repo.get_by_name(class_name, scene_id) + if not ontology_class: + logger.warning(f"Class not found: {class_name} in scene {scene_id}") + raise ValueError("类型不存在") + + return ontology_class + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to get class by name: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def search_classes_by_name( + self, + keyword: str, + scene_id: Any, + workspace_id: Any + ) -> List: + """根据关键词模糊搜索类型 + + Args: + keyword: 搜索关键词 + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + List[OntologyClass]: 匹配的类型列表 + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 搜索失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> classes = service.search_classes_by_name("患者", scene_id, workspace_id) + """ + logger.debug( + f"Searching classes by keyword: {keyword}, " + f"scene_id: {scene_id}, workspace_id: {workspace_id}" + ) + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景") + + # 搜索类型 + classes = self.class_repo.search_by_name(keyword, scene_id) + + logger.info( + f"Found {len(classes)} classes matching keyword '{keyword}' " + f"in scene {scene_id}" + ) + + return classes + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to search classes by keyword: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def list_classes_by_scene( + self, + scene_id: Any, + workspace_id: Any + ) -> List: + """获取场景下的所有类型 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + List[OntologyClass]: 类型列表 + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 查询失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> classes = service.list_classes_by_scene(scene_id, workspace_id) + """ + logger.debug(f"Listing classes for scene: {scene_id}") + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景的类型") + + # 获取类型列表 + classes = self.class_repo.get_by_scene(scene_id) + + logger.info(f"Found {len(classes)} classes in scene {scene_id}") + + return classes + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to list classes: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 9e447214..2c0b57ac 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -1,3 +1,4 @@ +import os import re import uuid from typing import Any, AsyncGenerator @@ -18,7 +19,8 @@ from app.models.prompt_optimizer_model import ( ) from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository from app.repositories.prompt_optimizer_repository import ( - PromptOptimizerSessionRepository + PromptOptimizerSessionRepository, + PromptReleaseRepository ) from app.schemas.prompt_optimizer_schema import OptimizePromptResult @@ -28,6 +30,8 @@ logger = get_business_logger() class PromptOptimizerService: def __init__(self, db: Session): self.db = db + self.optim_repo = PromptOptimizerSessionRepository(self.db) + self.release_repo = PromptReleaseRepository(self.db) def get_model_config( self, @@ -78,10 +82,12 @@ class PromptOptimizerService: Returns: PromptOptimzerSession: The newly created prompt optimization session. """ - session = PromptOptimizerSessionRepository(self.db).create_session( + session = self.optim_repo.create_session( tenant_id=tenant_id, user_id=user_id ) + self.db.commit() + self.db.refresh(session) return session def get_session_message_history( @@ -106,7 +112,7 @@ class PromptOptimizerService: - role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'. - content (str): The content of the message. """ - history = PromptOptimizerSessionRepository(self.db).get_session_history( + history = self.optim_repo.get_session_history( session_id=session_id, user_id=user_id ) @@ -177,11 +183,12 @@ class PromptOptimizerService: base_url=api_config.api_base ), type=ModelType(model_config.type)) try: - with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f: + prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') + with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() rendered_system_message = Template(opt_system_prompt).render() - with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f: + with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f: opt_user_prompt = f.read() except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) @@ -296,4 +303,165 @@ class PromptOptimizerService: role=role, content=content ) + self.db.commit() + self.db.refresh(message) return message + + def save_prompt( + self, + tenant_id: uuid.UUID, + session_id: uuid.UUID, + title: str, + prompt: str + ) -> dict: + """ + Create and save a new prompt release for a given session. + + Args: + tenant_id (uuid.UUID): The ID of the tenant owning the prompt. + session_id (uuid.UUID): The ID of the session to associate with this prompt. + title (str): The title of the prompt release. + prompt (str): The content of the prompt. + + Returns: + dict: A dictionary containing: + - id (UUID): The unique ID of the created prompt release. + - session_id (UUID): The session ID linked to the release. + - title (str): The title of the prompt. + - prompt (str): The prompt content. + - created_at (int): Timestamp (in milliseconds) of when the prompt was created. + + Raises: + BusinessException: If a prompt release already exists for the given session. + """ + session = self.optim_repo.get_session_by_id(session_id) + if session is None or session.tenant_id != tenant_id: + raise BusinessException( + "Session does not exist or the current user has no access", + BizCode.BAD_REQUEST + ) + + if self.release_repo.get_prompt_by_session_id(session_id): + raise BusinessException( + "A release already exists for the current session", + BizCode.BAD_REQUEST + ) + + prompt_obj = self.release_repo.create_prompt_release( + tenant_id=tenant_id, + title=title, + session_id=session_id, + prompt=prompt + ) + self.db.commit() + self.db.refresh(prompt_obj) + return { + "id": prompt_obj.id, + "session_id": prompt_obj.session_id, + "title": prompt_obj.title, + "prompt": prompt_obj.prompt, + "created_at": int(prompt_obj.created_at.timestamp() * 1000) + } + + def delete_prompt( + self, + tenant_id: uuid.UUID, + prompt_id: uuid.UUID + ) -> None: + """ + Soft delete a prompt release by prompt_id. + + Args: + tenant_id (uuid.UUID): Tenant identifier. + prompt_id (uuid.UUID): Prompt identifier. + + Raises: + BusinessException: If the prompt does not exist or already deleted. + """ + prompt_obj = self.release_repo.get_prompt_by_id(prompt_id) + if not prompt_obj or prompt_obj.is_delete: + raise BusinessException( + "Prompt does not exist or has already been deleted", + BizCode.NOT_FOUND + ) + + if prompt_obj.tenant_id != tenant_id: + raise BusinessException( + "No permission to delete this prompt", + BizCode.FORBIDDEN + ) + + self.release_repo.soft_delete_prompt(prompt_obj) + self.db.commit() + logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}") + + def get_release_list( + self, + tenant_id: uuid.UUID, + page: int, + page_size: int, + filter_keyword: str | None = None + ) -> dict[str, int | list[Any]]: + """ + Get paginated list of prompt releases with optional filter. + + Args: + tenant_id (uuid.UUID): Tenant identifier. + page (int): Page number (starting from 1). + page_size (int): Number of items per page. + filter_keyword (str | None): Optional keyword to filter by title. + + Returns: + dict: Contains total count, pagination info, and list of releases. + """ + offset = (page - 1) * page_size + + # Get total count and releases based on filter + if filter_keyword: + total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword) + releases = self.release_repo.search_prompts_paginated( + tenant_id=tenant_id, + keyword=filter_keyword, + offset=offset, + limit=page_size + ) + else: + total = self.release_repo.count_prompts(tenant_id) + releases = self.release_repo.get_prompts_paginated( + tenant_id=tenant_id, + offset=offset, + limit=page_size + ) + + items = [] + for release in releases: + # Get first user message from session + first_message = self.optim_repo.get_first_user_message( + session_id=release.session_id + ) + + items.append({ + "id": release.id, + "title": release.title, + "prompt": release.prompt, + "created_at": int(release.created_at.timestamp() * 1000), + "first_message": first_message + }) + + log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}" + if filter_keyword: + log_msg += f", filter='{filter_keyword}'" + logger.info(log_msg) + + result = { + "page": { + "total": total, + "page": page, + "page_size": page_size, + "hasnext": page * page_size < total + }, + "keyword": filter_keyword, + "items": items + } + + return result diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 1d012088..a92c2649 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -282,7 +282,14 @@ class SharedChatService: self.conversation_service.save_conversation_messages( conversation_id=conversation.id, user_message=message, - assistant_message=result["content"] + assistant_message=result["content"], + meta_data={ + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } ) # self.conversation_service.add_message( # conversation_id=conversation.id, @@ -469,6 +476,7 @@ class SharedChatService: # 流式调用 Agent full_content = "" + total_tokens = 0 async for chunk in agent.chat_stream( message=message, history=history, @@ -479,9 +487,12 @@ class SharedChatService: config_id=config_id, memory_flag=memory_flag ): - full_content += chunk - # 发送消息块事件 - yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" + if isinstance(chunk, int): + total_tokens = chunk + else: + full_content += chunk + # 发送消息块事件 + yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" elapsed_time = time.time() - start_time @@ -498,7 +509,7 @@ class SharedChatService: content=full_content, meta_data={ "model": api_key_obj.model_name, - "usage": {} + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} } ) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 3a90a821..d5f03e85 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.conversation_repository import ConversationRepository from app.repositories.end_user_repository import EndUserRepository +from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.services.implicit_memory_service import ImplicitMemoryService @@ -1508,7 +1509,6 @@ async def analytics_graph_data( user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) - if not end_user: logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") return { @@ -1562,21 +1562,11 @@ async def analytics_graph_data( } else: # 查询所有节点 - node_query = """ - MATCH (n) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) as id, - labels(n)[0] as label, - properties(n) as properties - LIMIT $limit - """ + node_query=Graph_Node_query node_params = { "end_user_id": end_user_id, "limit": limit } - - # 执行节点查询 node_results = await _neo4j_connector.execute_query(node_query, **node_params) @@ -1587,9 +1577,9 @@ async def analytics_graph_data( for record in node_results: node_id = record["id"] - node_label = record["label"] + node_labels = record.get("labels", []) + node_label = node_labels[0] if node_labels else "Unknown" node_props = record["properties"] - # 根据节点类型提取需要的属性字段 filtered_props = await _extract_node_properties(node_label, node_props,node_id) diff --git a/api/app/tasks.py b/api/app/tasks.py index cdd7945e..db332816 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -774,7 +774,15 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: } -@celery_app.task(name="app.tasks.regenerate_memory_cache", bind=True) +@celery_app.task( + name="app.tasks.regenerate_memory_cache", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, +) def regenerate_memory_cache(self) -> Dict[str, Any]: """定时任务:为所有用户重新生成记忆洞察和用户摘要缓存 @@ -966,7 +974,16 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } -@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True) + +@celery_app.task( + name="app.tasks.workspace_reflection_task", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=300, + soft_time_limit=240, +) def workspace_reflection_task(self) -> Dict[str, Any]: """定时任务:每30秒运行工作空间反思功能 @@ -1049,6 +1066,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") except Exception as e: + db.rollback() # Rollback failed transaction to allow next query api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") all_reflection_results.append({ "workspace_id": str(workspace_id), @@ -1111,7 +1129,16 @@ def workspace_reflection_task(self) -> Dict[str, Any]: -@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True) + +@celery_app.task( + name="app.tasks.run_forgetting_cycle_task", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=7200, + soft_time_limit=7000, +) def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: """定时任务:运行遗忘周期 @@ -1178,3 +1205,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di return result finally: loop.close() + + +# ============================================================================= +# Long-term Memory Storage Tasks (Batched Write Strategies) +# ============================================================================= + +@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True) +def long_term_storage_window_task( + self, + end_user_id: str, + langchain_messages: List[Dict[str, Any]], + config_id: str, + scope: int = 6 +) -> Dict[str, Any]: + """Celery task for window-based long-term memory storage. + + Accumulates messages in Redis buffer until window size (scope) is reached, + then writes batched messages to Neo4j. + + Args: + end_user_id: End user identifier + langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] + config_id: Memory configuration ID + scope: Window size (number of messages before triggering write) + + Returns: + Dict containing task status and metadata + """ + from app.core.logging_config import get_logger + logger = get_logger(__name__) + + logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}") + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue + from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format + from app.core.memory.agent.utils.redis_tool import write_store + from app.services.memory_config_service import MemoryConfigService + + db = next(get_db()) + try: + # Save to Redis buffer first + write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + + # Load memory config + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="LongTermStorageTask" + ) + + # Execute window-based dialogue storage + await window_dialogue(end_user_id, langchain_messages, memory_config, scope) + + return {"status": "SUCCESS", "strategy": "window", "scope": scope} + finally: + db.close() + + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + + logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s") + + return { + **result, + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + except Exception as e: + elapsed_time = time.time() - start_time + logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True) + + return { + "status": "FAILURE", + "strategy": "window", + "error": str(e), + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True) +# def long_term_storage_time_task( +# self, +# end_user_id: str, +# config_id: str, +# time_window: int = 5 +# ) -> Dict[str, Any]: +# """Celery task for time-based long-term memory storage. + +# Retrieves recent sessions from Redis within time window and writes to Neo4j. + +# Args: +# end_user_id: End user identifier +# config_id: Memory configuration ID +# time_window: Time window in minutes for retrieving recent sessions + +# Returns: +# Dict containing task status and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute time-based storage +# await memory_long_term_storage(end_user_id, memory_config, time_window) + +# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window} +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "time", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True) +# def long_term_storage_aggregate_task( +# self, +# end_user_id: str, +# langchain_messages: List[Dict[str, Any]], +# config_id: str +# ) -> Dict[str, Any]: +# """Celery task for aggregate-based long-term memory storage. + +# Uses LLM to determine if new messages describe the same event as history. +# Only writes to Neo4j if messages represent new information (not duplicates). + +# Args: +# end_user_id: End user identifier +# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] +# config_id: Memory configuration ID + +# Returns: +# Dict containing task status, is_same_event flag, and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment +# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format +# from app.core.memory.agent.utils.redis_tool import write_store +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Save to Redis buffer first +# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute aggregate judgment +# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config) + +# return { +# "status": "SUCCESS", +# "strategy": "aggregate", +# "is_same_event": result.get("is_same_event", False), +# "wrote_to_neo4j": not result.get("is_same_event", False) +# } +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "aggregate", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } diff --git a/api/app/utils/config_utils.py b/api/app/utils/config_utils.py index cc67afd2..55cfe8a3 100644 --- a/api/app/utils/config_utils.py +++ b/api/app/utils/config_utils.py @@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports. """ from uuid import UUID from sqlalchemy.orm import Session +import uuid as uuid_module -def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID: +def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID: """ - 解析 config_id,如果是整数则通过 config_id_old 查找对应的 UUID + 解析 config_id,支持 UUID、UUID字符串、整数等多种格式 Args: - config_id: 配置ID(UUID 或整数) + config_id: 配置ID(UUID、UUID字符串 或 整数) db: 数据库会话 Returns: UUID: 解析后的配置ID Raises: - ValueError: 当找不到对应的配置时 + ValueError: 当找不到对应的配置时或格式无效时 """ - from app.models.memory_config_model import MemoryConfig - if isinstance(config_id, UUID): + + # 1. 如果已经是 UUID 类型,直接返回 + if isinstance(config_id, UUID): return config_id - if isinstance(config_id, str) and len(config_id)<=6: - memory_config = db.query(MemoryConfig).filter( - MemoryConfig.config_id_old == int(config_id) - ).first() - print(memory_config) - if not memory_config: - raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置") - return memory_config.config_id + + # 2. 如果是字符串类型 + if isinstance(config_id, str): + config_id_stripped = config_id.strip() + + # 2.1 尝试解析为 UUID(标准 UUID 字符串长度为 36) + try: + return uuid_module.UUID(config_id_stripped) + except ValueError: + pass + + # 2.2 尝试解析为整数(用于查询 config_id_old) + try: + old_id = int(config_id_stripped) + if old_id > 0: + memory_config = db.query(MemoryConfig).filter( + MemoryConfig.config_id_old == old_id + ).first() + if not memory_config: + raise ValueError(f"未找到 config_id_old={old_id} 对应的配置") + return memory_config.config_id + except ValueError: + pass + + # 2.3 无法解析的字符串格式 + raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)") + + # 3. 如果是整数类型,通过 config_id_old 查找 if isinstance(config_id, int): + if config_id <= 0: + raise ValueError(f"config_id 必须是正整数: {config_id}") + memory_config = db.query(MemoryConfig).filter( MemoryConfig.config_id_old == config_id ).first() if not memory_config: - raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置") + raise ValueError(f"未找到 config_id_old={config_id} 对应的配置") return memory_config.config_id - return config_id + # 4. 不支持的类型 + raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}") diff --git a/api/app/version_info.json b/api/app/version_info.json index 86a5e33e..e82243a4 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,32 @@ { + "v0.2.2": { + "introduction": { + "codeName": "淬锋(Temper)", + "releaseDate": "2026-1-31", + "upgradePosition": "本次发布聚焦平台稳定性和性能优化。正如\"淬锋\"之名——千锤百炼,淬火成锋,我们通过严格测试和修复打磨系统品质。引入 Agent 工作流的代码执行能力、改进模型并发管理,并修复了记忆系统的多个关键问题。", + "coreUpgrades": [ + "1. Agent平台增强
* 模型并发管理:优化模型广场的并发请求处理和资源分配能力。", + "2. 记忆系统优化
* Celery 队列修复:解决任务队列问题,提升异步记忆处理的可靠性
* 记忆 Agent 优化:提升记忆 Agent 的性能和效率
* 接口响应速度优化:优化记忆接口响应时间,加快操作速度。", + "3. 情绪记忆与识别升级
* 情绪记忆角色识别修复:解决情绪记忆上下文中的角色/人物识别问题
* 角色识别增强:提升对话记忆中的角色/人物识别准确性。", + "
", + "MemoryBear 持续致力于为 AI 应用提供类人记忆能力。本次以稳定性为核心的发布,进一步夯实了「感知→精炼→关联→遗忘」范式的基础。", + "未来版本将在此坚实基础上,扩展 Agent 能力并深化记忆智能特性。" + ] + }, + "introduction_en": { + "codeName": "Temper (淬锋)", + "releaseDate": "2026-1-31", + "upgradePosition": "This release focuses on platform stability and performance optimization — true to its codename \"淬锋\" (tempered blade), we've refined the system through rigorous testing and fixes. Introducing Python code execution for Agent workflows, improved model concurrency management, and critical fixes across the memory system.", + "coreUpgrades": [ + "1. Agent Platform Enhancements
* Model Concurrency Management: Enhanced Model Plaza with improved concurrent model request handling and resource allocation.", + "2. Memory System Improvements
* Celery Queue Fix: Resolved task queue issues for more reliable asynchronous memory processing
* Memory Agent Optimization: Improved memory Agent performance and efficiency
* API Response Speed: Optimized memory interface response times for faster operations.", + "3. Emotional Memory & Recognition Upgrades
* Emotion Memory Role Recognition Fix: Resolved issues with role/character identification in emotional memory contexts
* Role Recognition Enhancement: Improved character/role identification accuracy in conversation memory.", + "
", + "MemoryBear continues advancing toward human-like memory capabilities for AI applications. This stability-focused release strengthens the foundation for our Perception → Refinement → Association → Forgetting paradigm.", + "Future releases will build on this solid base with expanded Agent capabilities and deeper memory intelligence features." + ] + } + }, "v0.2.1": { "introduction": { "codeName": "启知", diff --git a/api/docker-compose.yml b/api/docker-compose.yml index f30220cb..69763de2 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -19,6 +19,7 @@ services: depends_on: - worker-memory - worker-document + - worker-periodic # Memory worker - Memory read/write tasks (threads pool for asyncio) worker-memory: @@ -48,6 +49,20 @@ services: networks: - celery + # Periodic worker - Scheduled/beat tasks (prefork, low concurrency) + worker-periodic: + image: redbear-mem-open:latest + container_name: worker-periodic + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=2 --queues=periodic_tasks --max-tasks-per-child=50 -n periodic_worker@%h + restart: unless-stopped + networks: + - celery + # Celery Beat - scheduler beat: image: redbear-mem-open:latest @@ -69,7 +84,7 @@ services: container_name: sandbox ports: - "8194" - command: /code/.venv/bin/python main.py + command: /code/.venv/bin/uvicorn main:app --host 0.0.0.0 --port 8194 --log-level debug restart: unless-stopped networks: - sandbox diff --git a/api/env.example b/api/env.example index 274049b9..98c96edc 100644 --- a/api/env.example +++ b/api/env.example @@ -1,4 +1,9 @@ +# Language Configuration +# Supported values: "zh" (Chinese), "en" (English) +# This controls the language used for memory summary titles and other generated content +DEFAULT_LANGUAGE=zh + # Neo4j Configuration (记忆系统数据库) NEO4J_URI= NEO4J_USERNAME= diff --git a/api/migrations/versions/550c10595967_202601301521.py b/api/migrations/versions/550c10595967_202601301521.py new file mode 100644 index 00000000..b2f531db --- /dev/null +++ b/api/migrations/versions/550c10595967_202601301521.py @@ -0,0 +1,78 @@ +"""202601301521 + +Revision ID: 550c10595967 +Revises: 5de9b1e28509 +Create Date: 2026-01-30 15:24:34.647440 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '550c10595967' +down_revision: Union[str, None] = '5de9b1e28509' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('ontology_scene', + sa.Column('scene_id', sa.UUID(), nullable=False, comment='场景ID'), + sa.Column('scene_name', sa.String(length=200), nullable=False, comment='场景名称'), + sa.Column('scene_description', sa.Text(), nullable=True, comment='场景描述'), + sa.Column('workspace_id', sa.UUID(), nullable=False, comment='所属工作空间ID'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'), + sa.ForeignKeyConstraint(['workspace_id'], ['workspaces.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('scene_id'), + sa.UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name') + ) + op.create_index(op.f('ix_ontology_scene_scene_id'), 'ontology_scene', ['scene_id'], unique=False) + op.create_index(op.f('ix_ontology_scene_workspace_id'), 'ontology_scene', ['workspace_id'], unique=False) + op.create_table('ontology_class', + sa.Column('class_id', sa.UUID(), nullable=False, comment='类型ID'), + sa.Column('class_name', sa.String(length=200), nullable=False, comment='类型名称'), + sa.Column('class_description', sa.Text(), nullable=True, comment='类型描述'), + sa.Column('scene_id', sa.UUID(), nullable=False, comment='所属场景ID'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'), + sa.ForeignKeyConstraint(['scene_id'], ['ontology_scene.scene_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('class_id') + ) + op.create_index(op.f('ix_ontology_class_class_id'), 'ontology_class', ['class_id'], unique=False) + op.create_index(op.f('ix_ontology_class_scene_id'), 'ontology_class', ['scene_id'], unique=False) + op.create_table('prompt_history', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('tenant_id', sa.UUID(), nullable=False, comment='Tenant ID'), + sa.Column('session_id', sa.UUID(), nullable=False, comment='Session ID'), + sa.Column('title', sa.String(), nullable=False, comment='Title'), + sa.Column('prompt', sa.Text(), nullable=False, comment='Prompt'), + sa.Column('created_at', sa.DateTime(), nullable=True, comment='Creation Time'), + sa.Column('is_delete', sa.Boolean(), nullable=True, comment='Delete'), + sa.ForeignKeyConstraint(['session_id'], ['prompt_opt_session_list.id'], ), + sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_prompt_history_created_at'), 'prompt_history', ['created_at'], unique=False) + op.create_index(op.f('ix_prompt_history_id'), 'prompt_history', ['id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_index(op.f('ix_prompt_history_id'), table_name='prompt_history') + op.drop_index(op.f('ix_prompt_history_created_at'), table_name='prompt_history') + op.drop_table('prompt_history') + op.drop_index(op.f('ix_ontology_class_scene_id'), table_name='ontology_class') + op.drop_index(op.f('ix_ontology_class_class_id'), table_name='ontology_class') + op.drop_table('ontology_class') + op.drop_index(op.f('ix_ontology_scene_workspace_id'), table_name='ontology_scene') + op.drop_index(op.f('ix_ontology_scene_scene_id'), table_name='ontology_scene') + op.drop_table('ontology_scene') + # ### end Alembic commands ### diff --git a/api/migrations/versions/9def72f79398_202601301850.py b/api/migrations/versions/9def72f79398_202601301850.py new file mode 100644 index 00000000..303a1578 --- /dev/null +++ b/api/migrations/versions/9def72f79398_202601301850.py @@ -0,0 +1,30 @@ +"""202601301850 + +Revision ID: 9def72f79398 +Revises: 550c10595967 +Create Date: 2026-01-30 18:51:18.290796 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '9def72f79398' +down_revision: Union[str, None] = '550c10595967' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('memory_config', sa.Column('scene_id', sa.UUID(), nullable=True, comment='本体场景ID,关联ontology_scene表')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('memory_config', 'scene_id') + # ### end Alembic commands ### diff --git a/api/pyproject.toml b/api/pyproject.toml index 29597409..6d23a3b9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -140,6 +140,7 @@ dependencies = [ "oss2>=2.19.1", "flower>=2.0.1", "aiofiles>=23.0.0", + "owlready2>=0.46", ] [tool.pytest.ini_options] diff --git a/sandbox/Dockerfile b/sandbox/Dockerfile index 677b991c..e34b88dd 100644 --- a/sandbox/Dockerfile +++ b/sandbox/Dockerfile @@ -1,9 +1,10 @@ FROM python:3.12-slim USER root WORKDIR /code -LABEL authors="Eterntiy" -ARG NEED_MIRROR=0 +ARG NEED_MIRROR=1 +ENV DEBIAN_FRONTEND=noninteractive + RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ if [ "$NEED_MIRROR" == "1" ]; then \ @@ -17,11 +18,14 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ apt --no-install-recommends install -y ca-certificates && \ apt update && \ apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ + apt install -y nodejs npm && \ apt-get install -y --no-install-recommends tzdata libseccomp2 libseccomp-dev && \ ln -snf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ echo "Asia/Shanghai" > /etc/timezone && \ apt install -y cargo +ENV PYTHONDONTWRITEBYTECODE=1 + COPY ./app /code/app COPY ./dependencies /code/dependencies COPY ./lib /code/lib @@ -33,10 +37,15 @@ COPY ./requirements.txt /code/requirements.txt RUN python -m venv .venv RUN .venv/bin/python3 -m pip install -r requirements.txt -RUN cargo build --release --manifest-path lib/seccomp_python/Cargo.toml +RUN npm install --prefix=/code/dependencies/nodejs koffi -HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ +RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features python3 +RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libpython.so +RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features nodejs +RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libnodejs.so + +HEALTHCHECK --interval=30s --timeout=5s --start-period=60s --retries=3 \ CMD curl 127.0.0.1:8194/health -CMD [".venv/bin/python3", "main.py"] \ No newline at end of file +CMD [".venv/bin/uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8194", "--log-level", "debug"] \ No newline at end of file diff --git a/sandbox/app/__init__.py b/sandbox/app/__init__.py new file mode 100644 index 00000000..1b201ce5 --- /dev/null +++ b/sandbox/app/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/1/29 14:33 diff --git a/sandbox/app/config.py b/sandbox/app/config.py index 3fa4cab5..e4930465 100644 --- a/sandbox/app/config.py +++ b/sandbox/app/config.py @@ -4,9 +4,6 @@ from typing import List, Optional from pydantic import BaseModel, Field import yaml -SANDBOX_USER_ID = 1000 -SANDBOX_GROUP_ID = 1000 - DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [ "/usr/local/lib/python3.12", "/usr/lib/python3", @@ -15,13 +12,18 @@ DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [ "/etc/nsswitch.conf", "/etc/hosts", "/etc/resolv.conf", - "/run/systemd/resolve/stub-resolv.conf", - "/run/resolvconf/resolv.conf", "/etc/localtime", "/usr/share/zoneinfo", "/etc/timezone", ] +DEFAULT_NODEJS_LIB_REQUIREMENTS = [ + "/etc/ssl/certs/ca-certificates.crt", + "/etc/nsswitch.conf", + "/etc/resolv.conf", + "/etc/hosts", +] + class AppConfig(BaseModel): """Application configuration""" @@ -43,83 +45,77 @@ class Config(BaseModel): max_workers: int = 4 max_requests: int = 50 worker_timeout: int = 30 - nodejs_path: str = "node" + enable_network: bool = True enable_preload: bool = False python_path: str = "" python_lib_paths: list = Field(default=DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD) python_deps_update_interval: str = "30m" + + nodejs_path: str = "" + nodejs_lib_paths: list = Field(default=DEFAULT_NODEJS_LIB_REQUIREMENTS) + allowed_syscalls: List[int] = Field(default_factory=list) proxy: ProxyConfig = Field(default_factory=ProxyConfig) + sandbox_user: str = "sandbox" + sandbox_uid: int = 65537 + sandbox_gid: int = 0 + + def set_sandbox_gid(self, gid: int): + """Update sandbox GID dynamically""" + self.sandbox_gid = gid + + def override_with_env(self): + """Override configuration with environment variables""" + env_map = { + "DEBUG": ("app.debug", lambda v: v.lower() in ("true", "1", "yes")), + "MAX_WORKERS": ("max_workers", int), + "MAX_REQUESTS": ("max_requests", int), + "SANDBOX_PORT": ("app.port", int), + "WORKER_TIMEOUT": ("worker_timeout", int), + "API_KEY": ("app.key", str), + "NODEJS_PATH": ("nodejs_path", str), + "ENABLE_NETWORK": ("enable_network", lambda v: v.lower() in ("true", "1", "yes")), + "ENABLE_PRELOAD": ("enable_preload", lambda v: v.lower() in ("true", "1", "yes")), + "ALLOWED_SYSCALLS": ("allowed_syscalls", lambda v: [int(x) for x in v.split(",")]), + "SOCKS5_PROXY": ("proxy.socks5", str), + "HTTP_PROXY": ("proxy.http", str), + "HTTPS_PROXY": ("proxy.https", str), + "PYTHON_PATH": ("python_path", str), + "PYTHON_LIB_PATH": ("python_lib_paths", lambda v: v.split(",")), + "PYTHON_DEPS_UPDATE_INTERVAL": ("python_deps_update_interval", str), + "NODEJS_LIB_PATH": ("nodejs_lib_paths", lambda v: v.split(",")), + } + + for env_var, (attr_path, cast) in env_map.items(): + value = os.getenv(env_var) + if value is not None: + # Support nested attributes like 'app.debug' + parts = attr_path.split(".") + obj = self + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], cast(value)) + # Global configuration instance _config: Optional[Config] = None -def load_config(config_path: str) -> Config: - """Load configuration from YAML file""" +def load_config(config_path: str = "config.yaml") -> Config: + """Load configuration from YAML file and override with env variables""" global _config - - # Load from file if os.path.exists(config_path): with open(config_path, 'r') as f: - data = yaml.safe_load(f) + data = yaml.safe_load(f) or {} _config = Config(**data) else: _config = Config() - # Override with environment variables - if os.getenv("DEBUG"): - _config.app.debug = os.getenv("DEBUG").lower() in ("true", "1", "yes") - - if os.getenv("MAX_WORKERS"): - _config.max_workers = int(os.getenv("MAX_WORKERS")) - - if os.getenv("MAX_REQUESTS"): - _config.max_requests = int(os.getenv("MAX_REQUESTS")) - - if os.getenv("SANDBOX_PORT"): - _config.app.port = int(os.getenv("SANDBOX_PORT")) - - if os.getenv("WORKER_TIMEOUT"): - _config.worker_timeout = int(os.getenv("WORKER_TIMEOUT")) - - if os.getenv("API_KEY"): - _config.app.key = os.getenv("API_KEY") - - if os.getenv("NODEJS_PATH"): - _config.nodejs_path = os.getenv("NODEJS_PATH") - - if os.getenv("ENABLE_NETWORK"): - _config.enable_network = os.getenv("ENABLE_NETWORK").lower() in ("true", "1", "yes") - - if os.getenv("ENABLE_PRELOAD"): - _config.enable_preload = os.getenv("ENABLE_PRELOAD").lower() in ("true", "1", "yes") - - if os.getenv("ALLOWED_SYSCALLS"): - _config.allowed_syscalls = [int(x) for x in os.getenv("ALLOWED_SYSCALLS").split(",")] - - if os.getenv("SOCKS5_PROXY"): - _config.proxy.socks5 = os.getenv("SOCKS5_PROXY") - - if os.getenv("HTTP_PROXY"): - _config.proxy.http = os.getenv("HTTP_PROXY") - - if os.getenv("HTTPS_PROXY"): - _config.proxy.https = os.getenv("HTTPS_PROXY") - - # python - if os.getenv("PYTHON_PATH"): - _config.python_path = os.getenv("PYTHON_PATH") - - if os.getenv("PYTHON_LIB_PATH"): - _config.python_lib_paths = os.getenv("PYTHON_LIB_PATH").split(',') - - if os.getenv("PYTHON_DEPS_UPDATE_INTERVAL"): - _config.python_deps_update_interval = os.getenv("PYTHON_DEPS_UPDATE_INTERVAL") - + # Override from environment + _config.override_with_env() return _config diff --git a/sandbox/app/controllers/health_controller.py b/sandbox/app/controllers/health_controller.py index 4d872e58..882578ec 100644 --- a/sandbox/app/controllers/health_controller.py +++ b/sandbox/app/controllers/health_controller.py @@ -9,4 +9,4 @@ router = APIRouter() @router.get("/health", response_model=HealthResponse) async def health_check(): """Health check endpoint""" - return HealthResponse(status="healthy", version="2.0.0") + return HealthResponse(status="healthy", version="0.1.0") diff --git a/sandbox/app/controllers/sandbox_controller.py b/sandbox/app/controllers/sandbox_controller.py index 1a713f52..f9bc3fc0 100644 --- a/sandbox/app/controllers/sandbox_controller.py +++ b/sandbox/app/controllers/sandbox_controller.py @@ -2,13 +2,15 @@ from fastapi import APIRouter, Depends from app.middleware.auth import verify_api_key -from app.middleware.concurrency import check_max_requests, acquire_worker +from app.middleware.concurrency import concurrency_guard + from app.models import ( RunCodeRequest, ApiResponse, UpdateDependencyRequest, error_response ) +from app.services.nodejs_service import run_nodejs_code from app.services.python_service import ( run_python_code, list_python_dependencies, @@ -25,16 +27,14 @@ router = APIRouter( @router.post( "/run", response_model=ApiResponse, - dependencies=[Depends(check_max_requests), - Depends(acquire_worker)] + dependencies=[Depends(concurrency_guard)] ) async def run_code(request: RunCodeRequest): """Execute code in sandbox""" if request.language == "python3": return await run_python_code(request.code, request.preload, request.options) - elif request.language == "nodejs": - # TODO - return error_response(-400, "TODO") + elif request.language == "javascript": + return await run_nodejs_code(request.code, request.preload, request.options) else: return error_response(-400, "unsupported language") @@ -55,5 +55,3 @@ async def update_dependencies(request: UpdateDependencyRequest): return await update_python_dependencies() else: return error_response(-400, "unsupported language") - - diff --git a/sandbox/app/core/runners/__init__.py b/sandbox/app/core/runners/__init__.py index 96c5e380..b8021009 100644 --- a/sandbox/app/core/runners/__init__.py +++ b/sandbox/app/core/runners/__init__.py @@ -1 +1,40 @@ """Code runners package""" +import pwd +import subprocess + +from app.config import get_config +from app.logger import get_logger + +logger = get_logger() + + +def init_sandbox_user(): + config = get_config() + sandbox_user = config.sandbox_user + sandbox_uid = config.sandbox_uid + try: + pwd.getpwnam(sandbox_user) + logger.info(f"User '{sandbox_user}' already exists") + except KeyError: + try: + subprocess.run( + ["useradd", "-u", str(sandbox_uid), sandbox_user], + check=True, + capture_output=True, + text=True + ) + logger.info(f"Created user '{sandbox_user}' with UID {sandbox_uid}") + except subprocess.CalledProcessError as e: + logger.error(f"Failed to create user: {e.stderr}") + raise RuntimeError(f"Failed to create user '{sandbox_user}': {e.stderr}") from e + + try: + user_info = pwd.getpwnam(sandbox_user) + config.set_sandbox_gid(user_info.pw_gid) + logger.info(f"Sandbox user GID: {config.sandbox_gid}") + except KeyError as e: + logger.error(f"Failed to get GID for user '{sandbox_user}'") + raise RuntimeError(f"Failed to get GID for user '{sandbox_user}'") from e + + + diff --git a/sandbox/app/core/runners/nodejs/__init__.py b/sandbox/app/core/runners/nodejs/__init__.py new file mode 100644 index 00000000..fa5243b7 --- /dev/null +++ b/sandbox/app/core/runners/nodejs/__init__.py @@ -0,0 +1,3 @@ +from app.core.runners.nodejs.env import release_lib_binary + +release_lib_binary(True) diff --git a/sandbox/app/core/runners/nodejs/env.py b/sandbox/app/core/runners/nodejs/env.py new file mode 100644 index 00000000..8c6a55aa --- /dev/null +++ b/sandbox/app/core/runners/nodejs/env.py @@ -0,0 +1,124 @@ +import asyncio +import ctypes +import os +import shutil +import stat +import tempfile +from pathlib import Path + +from app.logger import get_logger +from app.config import get_config + +logger = get_logger() + +RELEASE_LIB_PATH = "./lib/seccomp_redbear/target/release/libnodejs.so" +LIB_PATH = "/var/sandbox/sandbox-nodejs" +LIB_NAME = "libnodejs.so" + +lib = ctypes.CDLL(RELEASE_LIB_PATH) +lib.get_lib_version_static.restype = ctypes.c_char_p +lib.get_lib_feature_static.restype = ctypes.c_char_p +logger.info(f"Seccomp Env: nodejs, " + f"Seccomp Feature: {lib.get_lib_feature_static().decode('utf-8')}, " + f"Seccomp Version: {lib.get_lib_version_static().decode('utf-8')}") + +try: + with open(RELEASE_LIB_PATH, "rb") as f: + _NODEJS_LIB = f.read() +except: + logger.critical("failed to load nodejs lib") + raise + + +def check_lib_avaiable(): + return os.path.exists(os.path.join(LIB_PATH, LIB_NAME)) + + +def release_lib_binary(force_remove: bool): + logger.info("init runtime enviroment") + + lib_file = os.path.join(LIB_PATH, LIB_NAME) + if os.path.exists(lib_file): + if force_remove: + try: + os.remove(lib_file) + except OSError: + logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}") + raise + + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_NODEJS_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + else: + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_NODEJS_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + + logger.info("nodejs runner environment initialized") + + +async def prepare_nodejs_dependencies_env(): + config = get_config() + + with tempfile.TemporaryDirectory(dir="/") as root_path: + root = Path(root_path) + + env_sh = root / "env.sh" + with open("script/env.sh") as f: + env_sh.write_text(f.read()) + env_sh.chmod(env_sh.stat().st_mode | stat.S_IXUSR) + + shutil.copytree("dependencies/nodejs", os.path.join(LIB_PATH, "node_temp"), dirs_exist_ok=True) + for root, dirs, files in os.walk(os.path.join(LIB_PATH, "node_temp")): + for d in dirs: + os.chmod(os.path.join(root, d), 0o755) + for f in files: + os.chmod(os.path.join(root, f), 0o444) + + for lib_path in config.nodejs_lib_paths: + lib_path = Path(lib_path) + + if not lib_path.exists(): + logger.warning("nodejs lib path %s is not available", lib_path) + continue + + cmd = [ + "bash", + str(env_sh), + str(lib_path), + str(LIB_PATH), + ] + + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + retcode = process.returncode + + if retcode != 0: + logger.error( + f"create env error for file {lib_path}: retcode={retcode}, stderr={stderr.decode()}" + ) diff --git a/sandbox/app/core/runners/nodejs/nodejs_runner.py b/sandbox/app/core/runners/nodejs/nodejs_runner.py new file mode 100644 index 00000000..59560eee --- /dev/null +++ b/sandbox/app/core/runners/nodejs/nodejs_runner.py @@ -0,0 +1,138 @@ +"""Nodejs code runner""" +import asyncio +import os +import uuid +from typing import Optional + +from app.core.executor import CodeExecutor, ExecutionResult +from app.core.runners.nodejs.env import check_lib_avaiable, release_lib_binary, LIB_PATH +from app.logger import get_logger +from app.models import RunnerOptions + +# Nodejs sandbox prescript template +with open("app/core/runners/nodejs/prescript.js") as f: + NODEJS_PRESCRIPT = f.read() + +logger = get_logger() + + +class NodejsRunner(CodeExecutor): + """Node.js code runner with security isolation""" + + def __init__(self): + super().__init__() + + @staticmethod + def init_environment(code: str, preload: str) -> str: + if not check_lib_avaiable(): + release_lib_binary(False) + code_file_name = uuid.uuid4().hex.replace("-", "_") + + script = NODEJS_PRESCRIPT.replace("{{preload}}", preload, 1) + + eval_code = f"eval(Buffer.from('{code}', 'base64').toString('utf-8'))" + script = script.replace("{{code}}", eval_code, 1) + + code_path = f"{LIB_PATH}/node_temp/tmp/{code_file_name}.js" + try: + os.makedirs(os.path.dirname(code_path), mode=0o755, exist_ok=True) + with open(code_path, "w", encoding="utf-8") as f: + f.write(script) + os.chmod(code_path, 0o755) + + except OSError as e: + raise RuntimeError(f"Failed to write {code_path}") from e + + return code_path + + async def run( + self, + code: str, + options: RunnerOptions, + preload: str = "", + timeout: Optional[int] = None + ) -> ExecutionResult: + """Run Python code in sandbox + + Args: + options: + code: Base64 encoded encrypted code + preload: Preload code to execute before main code + timeout: Execution timeout in seconds + + Returns: + ExecutionResult with stdout, stderr, and exit code + """ + config = self.config + + if timeout is None: + timeout = config.worker_timeout + + # Check if preload is allowed + if not preload or not config.enable_preload: + preload = "" + script_path = self.init_environment(code, preload) + + try: + # Setup environment + env = { + "UV_USE_IO_URING": "0" + } + + # Add proxy settings if configured + if config.proxy.socks5: + env["HTTPS_PROXY"] = config.proxy.socks5 + env["HTTP_PROXY"] = config.proxy.socks5 + elif config.proxy.https or config.proxy.http: + if config.proxy.https: + env["HTTPS_PROXY"] = config.proxy.https + if config.proxy.http: + env["HTTP_PROXY"] = config.proxy.http + + # Add allowed syscalls if configured + if config.allowed_syscalls: + env["ALLOWED_SYSCALLS"] = ",".join(map(str, config.allowed_syscalls)) + + process = await asyncio.create_subprocess_exec( + config.nodejs_path, + script_path, + LIB_PATH, + str(config.sandbox_uid), + str(config.sandbox_gid), + options.model_dump_json(), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + cwd=LIB_PATH + ) + + # Wait for completion with timeout + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=timeout + ) + + return ExecutionResult( + stdout=stdout.decode('utf-8', errors='replace'), + stderr=stderr.decode('utf-8', errors='replace'), + exit_code=process.returncode + ) + + except asyncio.TimeoutError: + # Kill process on timeout + try: + process.kill() + await process.wait() + except: + pass + + return ExecutionResult( + stdout="", + stderr="Execution timeout", + exit_code=-1, + ) + + finally: + # Cleanup temporary file + self.cleanup_temp_file(script_path) diff --git a/sandbox/app/core/runners/nodejs/prescript.js b/sandbox/app/core/runners/nodejs/prescript.js new file mode 100644 index 00000000..460aa108 --- /dev/null +++ b/sandbox/app/core/runners/nodejs/prescript.js @@ -0,0 +1,31 @@ +let argv = process.argv + +let koffi = require('koffi') + +process.chdir(argv[2]) + +let lib = koffi.load("./libnodejs.so") +/** @type {(uid: number, gid: number, enableNetwork: boolean) => number} */ +let initSeccomp = lib.func('int init_seccomp(int, int, bool)') + +let uid = parseInt(argv[3]) +let gid = parseInt(argv[4]) + +let options = JSON.parse(argv[5]) + +let seccomp_init = initSeccomp(uid, gid, options['enable_network']) +if (seccomp_init !== 0) { + throw `code executor err - ${seccomp_init}` +} + +delete process.argv +argv = undefined +koffi = undefined +lib = undefined +initSeccomp = undefined +uid = undefined +gid = undefined +options = undefined +seccomp_init = undefined + +{{code}} diff --git a/sandbox/app/core/runners/python/__init__.py b/sandbox/app/core/runners/python/__init__.py index 99a56ef7..e1a34906 100644 --- a/sandbox/app/core/runners/python/__init__.py +++ b/sandbox/app/core/runners/python/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- -# Author: Eternity -# @Email: 1533512157@qq.com -# @Time : 2026/1/23 11:27 +from app.core.runners.python.env import release_lib_binary + +release_lib_binary(True) diff --git a/sandbox/app/core/runners/python/env.py b/sandbox/app/core/runners/python/env.py index d82b0522..541acc73 100644 --- a/sandbox/app/core/runners/python/env.py +++ b/sandbox/app/core/runners/python/env.py @@ -1,14 +1,80 @@ import asyncio -import tempfile +import ctypes +import os import stat +import tempfile from pathlib import Path from app.config import get_config -from app.core.runners.python.settings import LIB_PATH from app.logger import get_logger logger = get_logger() +RELEASE_LIB_PATH = "./lib/seccomp_redbear/target/release/libpython.so" +LIB_PATH = "/var/sandbox/sandbox-python" +LIB_NAME = "libpython.so" + +lib = ctypes.CDLL(RELEASE_LIB_PATH) +lib.get_lib_version_static.restype = ctypes.c_char_p +lib.get_lib_feature_static.restype = ctypes.c_char_p +logger.info(f"Seccomp Env: python3, " + f"Seccomp Feature: {lib.get_lib_feature_static().decode('utf-8')}, " + f"Seccomp Version: {lib.get_lib_version_static().decode('utf-8')}") + +try: + with open(RELEASE_LIB_PATH, "rb") as f: + _PYTHON_LIB = f.read() +except: + logger.critical("failed to load python lib") + raise + + +def check_lib_avaiable(): + return os.path.exists(os.path.join(LIB_PATH, LIB_NAME)) + + +def release_lib_binary(force_remove: bool): + logger.info("init runtime enviroment") + + lib_file = os.path.join(LIB_PATH, LIB_NAME) + if os.path.exists(lib_file): + if force_remove: + try: + os.remove(lib_file) + except OSError: + logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}") + raise + + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_PYTHON_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + else: + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_PYTHON_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + + logger.info("python runner environment initialized") + async def prepare_python_dependencies_env(): config = get_config() diff --git a/sandbox/app/core/runners/python/prescript.py b/sandbox/app/core/runners/python/prescript.py index 950710ea..b694fe9b 100644 --- a/sandbox/app/core/runners/python/prescript.py +++ b/sandbox/app/core/runners/python/prescript.py @@ -17,7 +17,7 @@ sys.excepthook = excepthook # Load security library if available lib = ctypes.CDLL("./libpython.so") lib.init_seccomp.argtypes = [ctypes.c_uint32, ctypes.c_uint32, ctypes.c_bool] -lib.init_seccomp.restype = None # TODO: raise error info +lib.init_seccomp.restype = ctypes.c_int # Get running path running_path = sys.argv[1] @@ -37,7 +37,10 @@ os.chdir(running_path) {{preload}} # Apply security if library is available -lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}}) +init_status = lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}}) +if init_status != 0: + raise Exception(f"code executor err - {str(init_status)}") +del lib # Decrypt and execute code code = b64decode("{{code}}") diff --git a/sandbox/app/core/runners/python/python_runner.py b/sandbox/app/core/runners/python/python_runner.py index 30792b91..eccd16e0 100644 --- a/sandbox/app/core/runners/python/python_runner.py +++ b/sandbox/app/core/runners/python/python_runner.py @@ -5,10 +5,10 @@ import os import uuid from typing import Optional -from app.config import SANDBOX_USER_ID, SANDBOX_GROUP_ID, get_config +from app.config import get_config from app.core.encryption import generate_key, encrypt_code from app.core.executor import CodeExecutor, ExecutionResult -from app.core.runners.python.settings import check_lib_avaiable, release_lib_binary, LIB_PATH +from app.core.runners.python.env import check_lib_avaiable, release_lib_binary, LIB_PATH from app.logger import get_logger from app.models import RunnerOptions @@ -32,8 +32,8 @@ class PythonRunner(CodeExecutor): config = get_config() code_file_name = uuid.uuid4().hex.replace("-", "_") - script = PYTHON_PRESCRIPT.replace("{{uid}}", str(SANDBOX_USER_ID), 1) - script = script.replace("{{gid}}", str(SANDBOX_GROUP_ID), 1) + script = PYTHON_PRESCRIPT.replace("{{uid}}", str(config.sandbox_uid), 1) + script = script.replace("{{gid}}", str(config.sandbox_gid), 1) script = script.replace( "{{enable_network}}", str(int(options.enable_network and config.enable_network) diff --git a/sandbox/app/core/runners/python/settings.py b/sandbox/app/core/runners/python/settings.py deleted file mode 100644 index aee8827b..00000000 --- a/sandbox/app/core/runners/python/settings.py +++ /dev/null @@ -1,62 +0,0 @@ -import os - -from app.logger import get_logger - -logger = get_logger() - -RELEASE_LIB_PATH = "./lib/seccomp_python/target/release/libpython.so" -LIB_PATH = "/var/sandbox/sandbox-python" -LIB_NAME = "libpython.so" - -try: - with open(RELEASE_LIB_PATH, "rb") as f: - _PYTHON_LIB = f.read() -except: - logger.critical("failed to load python lib") - raise - - -def check_lib_avaiable(): - return os.path.exists(os.path.join(LIB_PATH, LIB_NAME)) - - -def release_lib_binary(force_remove: bool): - logger.info("init runtime enviroment") - lib_file = os.path.join(LIB_PATH, LIB_NAME) - if os.path.exists(lib_file): - if force_remove: - try: - os.remove(lib_file) - except OSError: - logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}") - raise - - try: - os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) - except OSError: - logger.critical(f"failed to create {LIB_PATH}") - raise - - try: - with open(lib_file, "wb") as f: - f.write(_PYTHON_LIB) - os.chmod(lib_file, 0o755) - except OSError: - logger.critical(f"failed to write {lib_file}") - raise - else: - try: - os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) - except OSError: - logger.critical(f"failed to create {LIB_PATH}") - raise - - try: - with open(lib_file, "wb") as f: - f.write(_PYTHON_LIB) - os.chmod(lib_file, 0o755) - except OSError: - logger.critical(f"failed to write {lib_file}") - raise - - logger.info("python runner environment initialized") diff --git a/sandbox/app/dependencies.py b/sandbox/app/dependencies.py index 6e88aaf2..6fe05ee4 100644 --- a/sandbox/app/dependencies.py +++ b/sandbox/app/dependencies.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import List, Dict from app.config import get_config +from app.core.runners.nodejs.env import prepare_nodejs_dependencies_env from app.core.runners.python.env import prepare_python_dependencies_env from app.logger import get_logger @@ -19,7 +20,10 @@ async def setup_dependencies(): logger.info("Preparing Python dependencies environment...") await prepare_python_dependencies_env() - logger.info("Python dependencies environment ready") + logger.info("Python Environment Ready ....") + logger.info("Preparing Nodejs dependencies environment...") + await prepare_nodejs_dependencies_env() + logger.info("Nodejs Environment Ready ...") except Exception as e: logger.error(f"Failed to setup dependencies: {e}") @@ -36,7 +40,7 @@ async def install_python_dependencies(): config = get_config() # Check if requirements file exists - req_file = Path("dependencies/python-requirements.txt") + req_file = Path("dependencies/python/python-requirements.txt") if not req_file.exists(): logger.warning("Python requirements file not found, skipping installation") return diff --git a/sandbox/app/logger.py b/sandbox/app/logger.py index de2ccc9e..9e63c8e5 100644 --- a/sandbox/app/logger.py +++ b/sandbox/app/logger.py @@ -12,25 +12,27 @@ def setup_logger() -> logging.Logger: """Setup application logger""" global _logger + if _logger is not None: + return _logger + config = get_config() # Create logger _logger = logging.getLogger("sandbox") _logger.setLevel(logging.DEBUG if config.app.debug else logging.INFO) - # Create console handler - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(logging.DEBUG if config.app.debug else logging.INFO) + # 只在 logger 没有 handler 时才添加 + if not _logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG if config.app.debug else logging.INFO) - # Create formatter - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - handler.setFormatter(formatter) + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + handler.setFormatter(formatter) - # Add handler to logger - _logger.addHandler(handler) + _logger.addHandler(handler) return _logger diff --git a/sandbox/app/middleware/concurrency.py b/sandbox/app/middleware/concurrency.py index 8d8325a4..e931f846 100644 --- a/sandbox/app/middleware/concurrency.py +++ b/sandbox/app/middleware/concurrency.py @@ -1,48 +1,66 @@ -"""Concurrency control middleware""" +""" +Concurrency control middleware +""" import asyncio +from contextlib import asynccontextmanager + from fastapi import HTTPException, status from app.config import get_config -from app.models import error_response +from app.logger import get_logger + +logger = get_logger() -# Global semaphores -_worker_semaphore: None | asyncio.Semaphore = None -_request_counter = 0 -_request_lock = asyncio.Lock() +class ConcurrencyController: + def __init__(self): + self._worker_semaphore: asyncio.Semaphore | None = None + self._request_counter = 0 + self._lock = asyncio.Lock() + + config = get_config() + self.max_requests = config.max_requests + + def init(self): + config = get_config() + self._worker_semaphore = asyncio.Semaphore(config.max_workers) + + async def _acquire_worker(self): + if self._worker_semaphore is None: + self.init() + async with self._worker_semaphore: + yield + + async def _limit_requests(self): + async with self._lock: + logger.info(f"Current requests: {self._request_counter}/{self.max_requests}") + if self._request_counter >= self.max_requests: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail={ + "code": 503, + "message": "Too many requests", + "data": None, + } + ) + self._request_counter += 1 + try: + yield + finally: + async with self._lock: + self._request_counter -= 1 + + def acquire_worker(self): + return asynccontextmanager(self._acquire_worker)() + + def limit_requests(self): + return asynccontextmanager(self._limit_requests)() -def init_concurrency_control(): - """Initialize concurrency control""" - global _worker_semaphore - config = get_config() - _worker_semaphore = asyncio.Semaphore(config.max_workers) +concurrency = ConcurrencyController() -async def check_max_requests(): - """Check if max requests limit is reached""" - global _request_counter - config = get_config() - - async with _request_lock: - if _request_counter >= config.max_requests: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=error_response(-503, "Too many requests") - ) - _request_counter += 1 - - try: - yield - finally: - async with _request_lock: - _request_counter -= 1 - - -async def acquire_worker(): - """Acquire a worker slot""" - if _worker_semaphore is None: - init_concurrency_control() - - async with _worker_semaphore: - yield +async def concurrency_guard(): + async with concurrency.limit_requests(): + async with concurrency.acquire_worker(): + yield diff --git a/sandbox/app/services/nodejs_service.py b/sandbox/app/services/nodejs_service.py new file mode 100644 index 00000000..ffd6127b --- /dev/null +++ b/sandbox/app/services/nodejs_service.py @@ -0,0 +1,43 @@ +"""Nodejs execution service""" +import signal + +from app.core.runners.nodejs.nodejs_runner import NodejsRunner +from app.logger import get_logger +from app.models import ( + success_response, + error_response, + RunCodeResponse, + RunnerOptions +) + + +async def run_nodejs_code(code: str, preload: str, options: RunnerOptions): + """Execute Node.js code in sandbox + + Args: + options: + code: Base64 encoded encrypted code + preload: Preload code + + Returns: + API response with execution result + """ + logger = get_logger() + + try: + runner = NodejsRunner() + result = await runner.run(code, options, preload) + if result.exit_code == signal.SIGSYS + 0x80: + return error_response(31, "sandbox security policy violation") + + if result.exit_code != 0: + return error_response(500, result.stderr) + + return success_response(RunCodeResponse( + stdout=result.stdout, + stderr=result.stderr + )) + + except Exception as e: + logger.error(f"Python execution failed: {e}", exc_info=True) + return error_response(-500, str(e)) diff --git a/sandbox/config.yaml b/sandbox/config.yaml index d9581b34..26fb9af3 100644 --- a/sandbox/config.yaml +++ b/sandbox/config.yaml @@ -1,13 +1,11 @@ app: - port: 8194 - debug: true key: redbear-sandbox -max_workers: 4 -max_requests: 50 -worker_timeout: 30 +max_workers: 10 +max_requests: 300 +worker_timeout: 15 python_path: /usr/local/bin/python -nodejs_path: /usr/local/bin/node +nodejs_path: /usr/bin/node enable_network: true enable_preload: false python_deps_update_interval: 30m diff --git a/sandbox/dependencies/nodejs/node_modules/.package-lock.json b/sandbox/dependencies/nodejs/node_modules/.package-lock.json new file mode 100644 index 00000000..28b290ef --- /dev/null +++ b/sandbox/dependencies/nodejs/node_modules/.package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "nodejs", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/sandbox/dependencies/nodejs/package-lock.json b/sandbox/dependencies/nodejs/package-lock.json new file mode 100644 index 00000000..28b290ef --- /dev/null +++ b/sandbox/dependencies/nodejs/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "nodejs", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/sandbox/dependencies/nodejs/package.json b/sandbox/dependencies/nodejs/package.json new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/sandbox/dependencies/nodejs/package.json @@ -0,0 +1 @@ +{} diff --git a/sandbox/dependencies/python-requirements.txt b/sandbox/dependencies/python/python-requirements.txt similarity index 100% rename from sandbox/dependencies/python-requirements.txt rename to sandbox/dependencies/python/python-requirements.txt diff --git a/sandbox/lib/seccomp_nodejs/Cargo.lock b/sandbox/lib/seccomp_nodejs/Cargo.lock deleted file mode 100644 index b37698ee..00000000 --- a/sandbox/lib/seccomp_nodejs/Cargo.lock +++ /dev/null @@ -1,7 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 4 - -[[package]] -name = "seccomp_nodejs" -version = "0.1.0" diff --git a/sandbox/lib/seccomp_nodejs/Cargo.toml b/sandbox/lib/seccomp_nodejs/Cargo.toml deleted file mode 100644 index a8bd8932..00000000 --- a/sandbox/lib/seccomp_nodejs/Cargo.toml +++ /dev/null @@ -1,6 +0,0 @@ -[package] -name = "seccomp_nodejs" -version = "0.1.0" -edition = "2024" - -[dependencies] \ No newline at end of file diff --git a/sandbox/lib/seccomp_nodejs/src/lib.rs b/sandbox/lib/seccomp_nodejs/src/lib.rs deleted file mode 100644 index e69de29b..00000000 diff --git a/sandbox/lib/seccomp_python/Cargo.lock b/sandbox/lib/seccomp_redbear/Cargo.lock similarity index 92% rename from sandbox/lib/seccomp_python/Cargo.lock rename to sandbox/lib/seccomp_redbear/Cargo.lock index 881ad177..f81d17c0 100644 --- a/sandbox/lib/seccomp_python/Cargo.lock +++ b/sandbox/lib/seccomp_redbear/Cargo.lock @@ -15,8 +15,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60276e2d41bbb68b323e566047a1bfbf952050b157d8b5cdc74c07c1bf4ca3b6" [[package]] -name = "seccomp_python" -version = "0.1.0" +name = "seccomp_redbear" +version = "0.1.1" dependencies = [ "libc", "libseccomp-sys", diff --git a/sandbox/lib/seccomp_python/Cargo.toml b/sandbox/lib/seccomp_redbear/Cargo.toml similarity index 51% rename from sandbox/lib/seccomp_python/Cargo.toml rename to sandbox/lib/seccomp_redbear/Cargo.toml index 07037172..d6535987 100644 --- a/sandbox/lib/seccomp_python/Cargo.toml +++ b/sandbox/lib/seccomp_redbear/Cargo.toml @@ -1,12 +1,17 @@ [package] -name = "seccomp_python" -version = "0.1.0" +name = "seccomp_redbear" +version = "0.1.1" edition = "2024" [lib] -name = "python" +name = "sandbox" crate-type = ["cdylib"] [dependencies] libc = "0.2.180" libseccomp-sys = "0.3.0" + +[features] +default = [] +python3 = [] +nodejs = [] diff --git a/sandbox/lib/seccomp_python/src/lib.rs b/sandbox/lib/seccomp_redbear/src/lib.rs similarity index 82% rename from sandbox/lib/seccomp_python/src/lib.rs rename to sandbox/lib/seccomp_redbear/src/lib.rs index 08b46c54..9de38a56 100644 --- a/sandbox/lib/seccomp_python/src/lib.rs +++ b/sandbox/lib/seccomp_redbear/src/lib.rs @@ -1,13 +1,25 @@ -mod syscalls; +#[cfg(all(feature = "python3", feature = "nodejs"))] +compile_error!("Only one feature can be enabled: either python3 or nodejs, not both!"); -use crate::syscalls::*; -use libc::{chdir, chroot, gid_t, uid_t, c_int}; +#[cfg(not(any(feature = "python3", feature = "nodejs")))] +compile_error!("You must enable one feature: either python3 or nodejs"); + +#[cfg(feature = "python3")] +mod python_syscalls; +#[cfg(feature = "python3")] +use crate::python_syscalls::*; + +#[cfg(feature = "nodejs")] +mod nodejs_syscalls; +#[cfg(feature = "nodejs")] +use crate::nodejs_syscalls::*; + +use libc::{c_char, c_int, chdir, chroot, gid_t, uid_t}; use libseccomp_sys::*; use std::env; use std::ffi::CString; use std::str::FromStr; - /* * get_allowed_syscalls - retrieve allowed syscalls for the sandbox * @enable_network: enable network-related syscalls if non-zero @@ -193,3 +205,20 @@ pub unsafe extern "C" fn init_seccomp(uid: uid_t, gid: gid_t, enable_network: i3 Err(code) => code, } } + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn get_lib_version_static() -> *const c_char { + concat!(env!("CARGO_PKG_VERSION"), "\0").as_ptr() as *const c_char +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn get_lib_feature_static() -> *const c_char { + #[cfg(feature = "python3")] + let s = b"python3\0"; + #[cfg(feature = "nodejs")] + let s = b"nodejs\0"; + #[cfg(not(any(feature = "python3", feature = "nodejs")))] + let s = b"none\0"; + + s.as_ptr() as *const c_char +} diff --git a/sandbox/lib/seccomp_redbear/src/nodejs_syscalls.rs b/sandbox/lib/seccomp_redbear/src/nodejs_syscalls.rs new file mode 100644 index 00000000..7cf36664 --- /dev/null +++ b/sandbox/lib/seccomp_redbear/src/nodejs_syscalls.rs @@ -0,0 +1,74 @@ +// src/nodejs_syscalls.rs + +pub static ALLOW_SYSCALLS: &[i32] = &[ + // File IO + libc::SYS_open as i32, + libc::SYS_write as i32, + libc::SYS_close as i32, + libc::SYS_read as i32, + libc::SYS_openat as i32, + libc::SYS_newfstatat as i32, + libc::SYS_ioctl as i32, + libc::SYS_lseek as i32, + libc::SYS_fstat as i32, + libc::SYS_readlink as i32, + libc::SYS_dup3 as i32, + libc::SYS_fcntl as i32, + libc::SYS_fsync as i32, + // Memory + libc::SYS_mprotect as i32, + libc::SYS_mmap as i32, + libc::SYS_munmap as i32, + libc::SYS_mremap as i32, + libc::SYS_brk as i32, + libc::SYS_madvise as i32, + // Signal + libc::SYS_rt_sigaction as i32, + libc::SYS_rt_sigprocmask as i32, + libc::SYS_sigaltstack as i32, + libc::SYS_rt_sigreturn as i32, + libc::SYS_tgkill as i32, + // Thread + libc::SYS_futex as i32, + libc::SYS_sched_yield as i32, + libc::SYS_set_robust_list as i32, + libc::SYS_rseq as i32, + // User / Group + libc::SYS_getuid as i32, + // Process + libc::SYS_getpid as i32, + libc::SYS_gettid as i32, + libc::SYS_exit as i32, + libc::SYS_exit_group as i32, + libc::SYS_sched_getaffinity as i32, + // Time + libc::SYS_clock_gettime as i32, + libc::SYS_gettimeofday as i32, + libc::SYS_nanosleep as i32, + libc::SYS_time as i32, + // Epoll / Event (I/O multiplexing) + libc::SYS_epoll_ctl as i32, + libc::SYS_epoll_pwait as i32, +]; + +pub static ALLOW_ERROR_SYSCALLS: &[i32] = &[libc::SYS_clone as i32, libc::SYS_clone3 as i32]; + +pub static ALLOW_NETWORK_SYSCALLS: &[i32] = &[ + libc::SYS_socket as i32, + libc::SYS_connect as i32, + libc::SYS_bind as i32, + libc::SYS_listen as i32, + libc::SYS_accept as i32, + libc::SYS_sendto as i32, + libc::SYS_recvfrom as i32, + libc::SYS_getsockname as i32, + libc::SYS_recvmsg as i32, + libc::SYS_getpeername as i32, + libc::SYS_setsockopt as i32, + libc::SYS_ppoll as i32, + libc::SYS_uname as i32, + libc::SYS_sendmsg as i32, + libc::SYS_getsockopt as i32, + libc::SYS_fcntl as i32, + libc::SYS_fstatfs as i32, +]; diff --git a/sandbox/lib/seccomp_python/src/syscalls.rs b/sandbox/lib/seccomp_redbear/src/python_syscalls.rs similarity index 90% rename from sandbox/lib/seccomp_python/src/syscalls.rs rename to sandbox/lib/seccomp_redbear/src/python_syscalls.rs index 961fffac..998ae390 100644 --- a/sandbox/lib/seccomp_python/src/syscalls.rs +++ b/sandbox/lib/seccomp_redbear/src/python_syscalls.rs @@ -1,7 +1,7 @@ -// src/syscalls.rs +// src/python_syscalls.rs pub static ALLOW_SYSCALLS: &[i32] = &[ - // file io + // File IO libc::SYS_read as i32, libc::SYS_write as i32, libc::SYS_openat as i32, @@ -11,48 +11,44 @@ pub static ALLOW_SYSCALLS: &[i32] = &[ libc::SYS_lseek as i32, libc::SYS_getdents64 as i32, libc::SYS_fstat as i32, - - // thread + // Signal + libc::SYS_rt_sigreturn as i32, + libc::SYS_rt_sigaction as i32, + libc::SYS_rt_sigprocmask as i32, + libc::SYS_sigaltstack as i32, + libc::SYS_tgkill as i32, + // Thread libc::SYS_futex as i32, - - // memory + // Memory libc::SYS_mmap as i32, libc::SYS_brk as i32, libc::SYS_mprotect as i32, libc::SYS_munmap as i32, - libc::SYS_rt_sigreturn as i32, libc::SYS_mremap as i32, - - // user / group - libc::SYS_setuid as i32, - libc::SYS_setgid as i32, + // User / Group libc::SYS_getuid as i32, - - // process + // Process libc::SYS_getpid as i32, libc::SYS_getppid as i32, libc::SYS_gettid as i32, libc::SYS_exit as i32, libc::SYS_exit_group as i32, - libc::SYS_tgkill as i32, - libc::SYS_rt_sigaction as i32, libc::SYS_sched_yield as i32, libc::SYS_set_robust_list as i32, libc::SYS_get_robust_list as i32, libc::SYS_rseq as i32, - - // time + // Time libc::SYS_clock_gettime as i32, libc::SYS_gettimeofday as i32, + libc::SYS_time as i32, libc::SYS_nanosleep as i32, + libc::SYS_clock_nanosleep as i32, + // Epoll / Event (I/O multiplexing) libc::SYS_epoll_create1 as i32, libc::SYS_epoll_ctl as i32, - libc::SYS_clock_nanosleep as i32, libc::SYS_pselect6 as i32, - libc::SYS_rt_sigprocmask as i32, - libc::SYS_sigaltstack as i32, + // Randomness libc::SYS_getrandom as i32, - ]; pub static ALLOW_ERROR_SYSCALLS: &[i32] = &[ diff --git a/sandbox/main.py b/sandbox/main.py index fc417563..99b7b0a6 100644 --- a/sandbox/main.py +++ b/sandbox/main.py @@ -11,51 +11,15 @@ from fastapi import FastAPI from app.config import get_config from app.controllers import manager_router +from app.core.runners import init_sandbox_user from app.dependencies import setup_dependencies, update_dependencies_periodically from app.logger import setup_logger, get_logger +setup_logger() +config = get_config() logger = get_logger() -@asynccontextmanager -async def lifespan(app: FastAPI): - """Application lifespan manager""" - logger = get_logger() - - # Startup - logger.info("Starting RedBear Sandbox...") - - # Setup dependencies in background - asyncio.create_task(setup_dependencies()) - - # Start periodic dependency updates - config = get_config() - if config.python_deps_update_interval: - asyncio.create_task(update_dependencies_periodically()) - - yield - - # Shutdown - logger.info("Shutting down Redbear Sandbox...") - - -def create_app() -> FastAPI: - """Create FastAPI application""" - config = get_config() - - app = FastAPI( - title="Sandbox", - description="Secure code execution sandbox", - version="2.0.0", - lifespan=lifespan, - debug=config.app.debug - ) - - app.include_router(manager_router) - - return app - - def check_root_privileges(): """Check if running with root privileges""" if os.geteuid() != 0: @@ -63,35 +27,38 @@ def check_root_privileges(): sys.exit(1) -def main(): - """Main entry point""" - # Check root privileges - check_root_privileges() +check_root_privileges() - # Setup logging - setup_logger() - config = get_config() +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager""" logger = get_logger() - + config = get_config() + # Startup + logger.info("Starting RedBear Sandbox...") logger.info(f"Starting server on port {config.app.port}") logger.info(f"Debug mode: {config.app.debug}") logger.info(f"Max workers: {config.max_workers}") logger.info(f"Max requests: {config.max_requests}") logger.info(f"Network enabled: {config.enable_network}") + init_sandbox_user() + await setup_dependencies() - # Create app - app = create_app() + if config.python_deps_update_interval: + asyncio.create_task(update_dependencies_periodically()) - # Run server - uvicorn.run( - app, - host="0.0.0.0", - port=config.app.port, - log_level="debug" if config.app.debug else "info", - access_log=config.app.debug - ) + yield + # Shutdown + logger.info("Shutting down Redbear Sandbox...") -if __name__ == "__main__": - main() +app = FastAPI( + title="Sandbox", + description="Secure code execution sandbox", + version="0.1.0", + lifespan=lifespan, + debug=config.app.debug +) + +app.include_router(manager_router) diff --git a/web/package.json b/web/package.json index e28e8b56..89800fcf 100644 --- a/web/package.json +++ b/web/package.json @@ -13,6 +13,14 @@ "@antv/layout": "^1.2.14-beta.8", "@antv/x6": "^3.0.1", "@antv/x6-react-shape": "^3.0.1", + "@codemirror/lang-cpp": "^6.0.3", + "@codemirror/lang-java": "^6.0.2", + "@codemirror/lang-javascript": "^6.2.4", + "@codemirror/lang-python": "^6.2.1", + "@codemirror/lang-rust": "^6.0.2", + "@codemirror/state": "^6.5.4", + "@codemirror/theme-one-dark": "^6.1.3", + "@codemirror/view": "^6.39.12", "@dnd-kit/core": "^6.3.1", "@dnd-kit/modifiers": "^9.0.0", "@dnd-kit/sortable": "^10.0.0", @@ -25,6 +33,7 @@ "antd": "^5.27.4", "axios": "^1.12.2", "clsx": "^2.1.1", + "codemirror": "^6.0.2", "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", "dayjs": "^1.11.18", @@ -55,6 +64,7 @@ "@tailwindcss/postcss": "^4.1.14", "@tailwindcss/typography": "^0.5.19", "@tailwindcss/vite": "^4.1.14", + "@types/codemirror": "^5.60.17", "@types/crypto-js": "^4.2.2", "@types/js-yaml": "^4.0.9", "@types/node": "^24.6.0", diff --git a/web/src/api/ontology.ts b/web/src/api/ontology.ts new file mode 100644 index 00000000..bb5244e4 --- /dev/null +++ b/web/src/api/ontology.ts @@ -0,0 +1,40 @@ +import { request } from '@/utils/request' +import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData } from '@/views/Ontology/types' + +// Scene list +export const getOntologyScenesSimpleUrl = '/memory/ontology/scenes/simple' +export const getOntologyScenesUrl = '/memory/ontology/scenes' +export const getOntologyScenesList = (data: Query) => { + return request.get(getOntologyScenesUrl, data) +} + +// Create scene +export const createOntologyScene = (data: OntologyModalData) => { + return request.post('/memory/ontology/scene', data) +} +// Update scene +export const updateOntologyScene = (scene_id: string, data: OntologyModalData) => { + return request.put(`/memory/ontology/scene/${scene_id}`, data) +} +// Delete scene +export const deleteOntologyScene = (scene_id: string) => { + return request.delete(`/memory/ontology/scene/${scene_id}`) +} + +// Get class list +export const getOntologyclassesUrl = '/memory/ontology/classes' +export const getOntologyClassList = (data: { scene_id: string; class_name?: string; }) => { + return request.get(getOntologyclassesUrl, data) +} +// Extract ontology types +export const extractOntologyTypes = (data: OntologyClassExtractModalData) => { + return request.post('/memory/ontology/extract', data) +} +// Create ontology class +export const createOntologyClass = (data: OntologyClassModalData) => { + return request.post('/memory/ontology/class', data) +} +// Delete ontology class +export const deleteOntologyClass = (class_id: string) => { + return request.delete(`/memory/ontology/class/${class_id}`) +} diff --git a/web/src/api/prompt.ts b/web/src/api/prompt.ts index 526f50ac..79ea374c 100644 --- a/web/src/api/prompt.ts +++ b/web/src/api/prompt.ts @@ -1,13 +1,26 @@ import { request } from '@/utils/request' import type { AiPromptForm } from '@/views/ApplicationConfig/types' +import type { PromptReleaseData } from '@/views/Prompt/types' import { handleSSE, type SSEMessage } from '@/utils/stream' +// Create session export const createPromptSessions = () => { return request.post(`/prompt/sessions`) } -export const getPrompt = (session_id: string) => { - return request.get(`/prompt/sessions/${session_id}`) -} +// Get prompt optimization export const updatePromptMessages = (session_id: string, data: AiPromptForm, onMessage?: (data: SSEMessage[]) => void) => { return handleSSE(`/prompt/sessions/${session_id}/messages`, data, onMessage) +} +// Prompt release list +export const getPromptReleaseListUrl = '/prompt/releases/list' +export const getPromptReleaseList = () => { + return request.get(getPromptReleaseListUrl) +} +// Save prompt +export const savePrompt = (data: PromptReleaseData) => { + return request.post('/prompt/releases', data) +} +// Delete prompt +export const deletePrompt = (prompt_id: string) => { + return request.delete(`/prompt/releases/${prompt_id}`) } \ No newline at end of file diff --git a/web/src/assets/images/menu/ontology.svg b/web/src/assets/images/menu/ontology.svg new file mode 100644 index 00000000..9bfda42b --- /dev/null +++ b/web/src/assets/images/menu/ontology.svg @@ -0,0 +1,11 @@ + + + 本体管理备份 + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/ontology_active.svg b/web/src/assets/images/menu/ontology_active.svg new file mode 100644 index 00000000..1271c2c3 --- /dev/null +++ b/web/src/assets/images/menu/ontology_active.svg @@ -0,0 +1,11 @@ + + + 本体管理 + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/prompt.svg b/web/src/assets/images/menu/prompt.svg new file mode 100644 index 00000000..ffef9a34 --- /dev/null +++ b/web/src/assets/images/menu/prompt.svg @@ -0,0 +1,15 @@ + + + 提示词备份 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/prompt_active.svg b/web/src/assets/images/menu/prompt_active.svg new file mode 100644 index 00000000..ac45e13c --- /dev/null +++ b/web/src/assets/images/menu/prompt_active.svg @@ -0,0 +1,15 @@ + + + 提示词 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/space/neo4j.png b/web/src/assets/images/space/neo4j.png new file mode 100644 index 0000000000000000000000000000000000000000..74fc7a861762c7e7dad4f9d3491b139c8f0583ed GIT binary patch literal 1424 zcmYk*c{tR090%|p6SFB)%aV?LSSeGnt?jNyHoA5b@>JS373o=IT-^xCksP@b#x>3y zVGzwUh6cIM+|4Lgga$*C>AllG`}91o&-ZgYpZ`A3H^9XBu)Lg_96=EBM~@hqNwUee z_<f`l_@ zE}*$0@fwn@Bk2Z`C2k_+7E*2_^)~2tKz9S(9cdm&zl#hnWO^g(9iIP zf8+)rFA(_;@hk`hAt($*Q5cHC!H|eRaU@EjP!f%@Sd=|Mc^sGtVA8-!1S<(IC6d8T z0Xr2Hba2wZNk>&CUS)xs4enFaF@^jPMDwaVi` z8UFKL^HP+Xn~h#5Dvi5q6!U`4*w#(u%^eKyvi`B(^#*J1@kZ*-u}1p@Z7Un|j&|9D z%WAE*ESIsf%ToQBbs&Zk?8v_IE=#;J?1BHwo`M&9!h*{Rj2iPc%9!dmoujP^$vAIm zb#0yYgR)ka+3 z`!I_gzb$j?nNnW9E!u;jxr|^Z*4;n8&S}Jfdrn7H=hh~B*H+^oQcaxQ^_u6NJr5NI zHE0uYtuFeveJWD<>J3AkPb_`iN;awLs=nzfNmtXF-pS(17~Hly=gdji?JpQrktv_s zU~@@WpCRM6RjyT;U>lO7YtFLce%-Z%sFZiR)FZ7ykxP?eCz+gD%)i;j*&Mfoa5!Bp z>tUg?V4`HdMcR|8Bs^a6Q@K&ac7^%2&zwg7eJ#03p3tzDk(^uq_u-Mo>r$Ybl~9AcOB?s`iS1%1e*E)P6N(wVq+o)xOf+ku^RF z%H;f!LgU2`$6|9XJBRSJU)|e(x5Pq++MI4>uH6+e^!nmtnI9>ZMBX*kaY&b}M`o2c zon)%jdY&ENh}YW-7Cu}cr*zJ_&o4Sx_7^bJo9rLSPIV{n8AD;b>t8|^#ffW=D3-d< zSVt^dH$Z>D{uMQ`^mLsqs~=%(GnnyeOz zY=F??`idJzJd0+v6@zO$6?)f&oAJngG^22-m%H~>Wh`F^o=>AYd{$6oMM&Sdm2WBc z&_!`ubg1-L>AHOY{-GrX)O71!-$ox{le({9TcE%oRA3q@xEL#Pq6xfGMUOHCNx7nDg`$d5Q43R~ euhK3SeyOwlJ*XRqsf+tT>D9jpq zLKZG$!8{8VuVDEKR0b#UIdRK{8Ego#qcb_wGvz}#f?(jB$UCc3{=8fc)x|uTihzgzv#G4$K49}R^VPG z{Howzg@9@V)ga^@LJ76_cP;MM;Xxh3>JeU#hz2}tz@tVyZbW1gqM8xif|wROZ3V3r zv8{+}Lqa>AwIi_ue{|r*2mINIq)sGvA*Bncgl?pDBfT3hdyvtK%wD|e!|Oi0=|^@y zay}w|0EL4n8ba|9N=H!j3AN*>8%M(g8Yj^-iRMYPd`2r_3h$@TMwrHjFX);<_gD0M zMK1$=4D=Ib@o^RdggFe(VTdq~;dzV@7BISizZUR`i7_U|7csen&r6tE#`H42tYBsZ zUs+(VFuRJmRm>CESYU%mSi|BPmcC*68&=lATE{8}Y!1F{V4VvN7aN34a5u3@!1qwH zreqSgB7V!<%!0IKtIU2)!w_0*RpY?K;?nX;_aGmQMB>*VpFL%1Q~M^Kp(1%fvid=b zzP1UOTqv1LH*QL2^Cht@_EuQS8(fl-no0Xz+{t2ZzD-009Q+gu4%80o7py)%&mK15 z4EedY8%z%ugnXg}{rOvNn0lvSc!Yqkppe*9Yw`ypYK^$;TaRs;T6(rCCdr|ER;|K2 zxUpXK%D-z!9ck*SA7xR;e8XHG2ahSM2{v|U+f1s~c+_@o1a*3KCn*~WvJxYWo>w=K z{%$n6@6>65D~@4e4Lf91&&uTMb9}s+tlDxKV?{)T|IA{Rq!{(ikZwUryrW0(jv$SB zY2nr$880;o*}byrWZqOsgr2=TLyaOSp;L14U|=Hcw;Ubr4^HGtEu%0MJAWSO;gW@9 zCFgCTj^j3YRL3OYpm>!+s|LH5+o{7&lBG8U20Nao-k^KPY&SeC!>75R6gi_niyI9U z8FUumk*AMRFKtGJY`Da{a}!l{IK=yYQ;lP15N|TaQqL~#&Rmk~%yy3uAt~-NDBa#M zqMF69j&`wrG-~?nc&=qY%Y0H}&us4>Tlbv#v8#l^tKWOlOL15J)pOS}V(9J^pKzbx zc*pF6jlSuT$JWQ6s%DFI>!g03A$d;oobo)j=`qwsrZ)(AyD8ZIOE~TkdLc&6B1iV+y2m#;`gM{%4@ief3-8H`qb%Lx`PLv2-SOS+!s_x z;;ktu`sgCLZ|!H7nz6sBG7s+*FmsMM??o3qLD5Te;{8=S=w)uRyKG$0I}M4on09&N zI@Mn%_ow(At>zC*TMCR4-)*>PuHm&2EO;$x@h6`{!&6!cC!Be2Qtmq26)BWo_$%+n zBHsz~ZMui{M(&E(?J6v=EyjF{?jiNi&;e1Y1N#1~tIW&hQc+#WLRT`H7|RCLA)6EG z{!)WIJi2{JrjHIf-naAYm>;`SF%hl2hvxm%sGAjSxz+JuVk)K6rN8Z7`M|Mhp61+L zGkz-)753k_11TKZr3Tk!)%D~m^%0p2h5U3Wj{X>x*5${I>$_*%t|*%!NxiOPt})>9 zNmhpz=iF=|WfK0eLQ^)Ya@@g)E_e0FgAnDl!gn?`zua~lwU_G%nP5^9I-~Uz{g_u6 z#%yq1K${HT7R!Hpb-*k8sMw6Ldg#KOFM zlH+v7m)ww*`uGQ{$M&ONi$`V|^=g%Y5rGbF*M!RwTn*9}XO9TekInlXePuO5)54EYcD-{Qak literal 0 HcmV?d00001 diff --git a/web/src/components/CodeMirrorEditor/index.tsx b/web/src/components/CodeMirrorEditor/index.tsx new file mode 100644 index 00000000..e100b75b --- /dev/null +++ b/web/src/components/CodeMirrorEditor/index.tsx @@ -0,0 +1,150 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-04 17:20:52 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 17:20:52 + */ +import { useEffect, useRef, useMemo } from 'react'; +import { EditorView, basicSetup } from 'codemirror'; +import { EditorState } from '@codemirror/state'; +import { python } from '@codemirror/lang-python'; +import { javascript } from '@codemirror/lang-javascript'; +import { java } from '@codemirror/lang-java'; +import { cpp } from '@codemirror/lang-cpp'; +import { rust } from '@codemirror/lang-rust'; +import { oneDark } from '@codemirror/theme-one-dark'; + +/** + * Props for the CodeMirrorEditor component + * @property {string} value - The initial code content to display in the editor + * @property {string} language - Programming language for syntax highlighting (python, python3, javascript, typescript, java, cpp, c, rust) + * @property {function} onChange - Callback function triggered when editor content changes, receives the new code value + * @property {string} theme - Editor theme, either 'light' or 'dark' + * @property {boolean} readOnly - Whether the editor is read-only + * @property {string} height - Custom height for the editor + * @property {string} size - Predefined size preset: 'default' (120px min-height, 14px font) or 'small' (60px min-height, 12px font) + */ +interface CodeMirrorEditorProps { + value?: string; + language?: 'python' | 'python3' | 'javascript' | 'typescript' | 'java' | 'cpp' | 'c' | 'rust'; + onChange?: (value: string) => void; + theme?: 'light' | 'dark'; + readOnly?: boolean; + height?: string; + size?: 'default' | 'small'; +} + +/** + * Map of language identifiers to their corresponding CodeMirror language extensions + * Supports multiple programming languages with syntax highlighting + */ +const languageExtensions: Record = { + python: python(), + python3: python(), + javascript: javascript(), + typescript: javascript({ typescript: true }), + java: java(), + cpp: cpp(), + c: cpp(), + rust: rust(), +}; + +/** + * CodeMirrorEditor - A React wrapper component for CodeMirror 6 editor + * Provides a code editor with syntax highlighting, theme support, and customizable sizing + * Used in workflow code execution nodes for editing Python and JavaScript code + */ +const CodeMirrorEditor = ({ + value = '', + language = 'javascript', + onChange, + theme = 'light', + readOnly = false, + size, +}: CodeMirrorEditorProps) => { + // Reference to the DOM element that will contain the editor + const editorRef = useRef(null); + // Reference to the CodeMirror EditorView instance + const viewRef = useRef(null); + + /** + * Initialize CodeMirror editor when component mounts or when language/theme/readOnly changes + * Sets up extensions for syntax highlighting, change listeners, and theme + */ + useEffect(() => { + if (!editorRef.current) return; + + // Get the appropriate language extension, fallback to JavaScript if not found + const langExtension = languageExtensions[language] || languageExtensions.javascript; + + // Configure editor extensions + const extensions = [ + basicSetup, // Basic editor features (line numbers, bracket matching, etc.) + langExtension, // Language-specific syntax highlighting + // Listen for document changes and trigger onChange callback + EditorView.updateListener.of((update) => { + if (update.docChanged && onChange) { + onChange(update.state.doc.toString()); + } + }), + EditorState.readOnly.of(readOnly), // Set read-only mode + ]; + + // Apply dark theme if specified + if (theme === 'dark') { + extensions.push(oneDark); + } + + // Create editor state with initial value and extensions + const state = EditorState.create({ + doc: value, + extensions, + }); + + // Create and mount the editor view + viewRef.current = new EditorView({ + state, + parent: editorRef.current, + }); + + // Cleanup: destroy editor instance when component unmounts or dependencies change + return () => { + viewRef.current?.destroy(); + }; + }, [language, theme, readOnly]); + + /** + * Update editor content when the value prop changes externally + * Only updates if the new value differs from current editor content + */ + useEffect(() => { + if (viewRef.current && value !== viewRef.current.state.doc.toString()) { + viewRef.current.dispatch({ + changes: { + from: 0, + to: viewRef.current.state.doc.length, + insert: value, + }, + }); + } + }, [value]); + + // Calculate minimum height based on size prop: small (60px) or default (120px) + const minHeight = useMemo(() => { + return `${size === 'small' ? 60 : 120}px` + }, [size]) + + // Calculate font size based on size prop: small (12px) or default (14px) + const fontSize = useMemo(() => { + return `${size === 'small' ? 12 : 14}px` + }, [size]) + + // Calculate line height based on size prop: small (16px) or default (20px) + const lineHeight = useMemo(() => { + return `${size === 'small' ? 16 : 20}px` + }, [size]) + + return
; +}; + +export default CodeMirrorEditor; diff --git a/web/src/components/CustomSelect/index.tsx b/web/src/components/CustomSelect/index.tsx index 1887d635..f93014c9 100644 --- a/web/src/components/CustomSelect/index.tsx +++ b/web/src/components/CustomSelect/index.tsx @@ -1,4 +1,4 @@ -import { useEffect, useState, type FC, type Key } from 'react'; +import { useEffect, useState, useMemo, type FC, type Key } from 'react'; import { Select } from 'antd'; import type { SelectProps, DefaultOptionType } from 'antd/es/select'; import { useTranslation } from 'react-i18next'; @@ -47,13 +47,14 @@ const CustomSelect: FC = ({ }) => { const { t } = useTranslation(); const [options, setOptions] = useState([]); + const memoizedParams = useMemo(() => params, [JSON.stringify(params)]); useEffect(() => { - request.get>(url, params).then((res) => { + request.get>(url, memoizedParams).then((res) => { const data = Array.isArray(res) ? res : res?.items || []; setOptions(data); }); - }, [url, params]); + }, [url, memoizedParams]); const displayOptions = format ? format(options) : options; diff --git a/web/src/components/Empty/BodyWrapper.tsx b/web/src/components/Empty/BodyWrapper.tsx index f9978184..9cdeb0e8 100644 --- a/web/src/components/Empty/BodyWrapper.tsx +++ b/web/src/components/Empty/BodyWrapper.tsx @@ -1,6 +1,6 @@ import type { FC, ReactNode } from 'react' -import { Skeleton } from 'antd' -import Empty from './index' +import PageEmpty from './PageEmpty' +import PageLoading from './PageLoading' interface BodyWrapperProps { children: ReactNode @@ -9,10 +9,10 @@ interface BodyWrapperProps { } const BodyWrapper: FC = ({ children, loading = false, empty }) => { if (loading) { - return + return } if (!loading && empty) { - return + return } return children } diff --git a/web/src/components/Markdown/index.tsx b/web/src/components/Markdown/index.tsx index 58650207..1a2c765d 100644 --- a/web/src/components/Markdown/index.tsx +++ b/web/src/components/Markdown/index.tsx @@ -19,6 +19,7 @@ interface RbMarkdownProps { showHtmlComments?: boolean; // 是否显示 HTML 注释,默认为 false(隐藏) editable?: boolean; // 是否可编辑,默认为 false onContentChange?: (content: string) => void; // 内容变化回调 + className?: string; } const components = { @@ -50,7 +51,7 @@ const components = { audio: ({ src, ...props }: any) => , a: ({ href, children, ...props }: any) => {children}, button: ({ children }: any) => {[children]}, - table: ({ children, ...props }: any) => {children}
, + table: ({ children, ...props }: any) =>
{children}
, tr: ({ children, ...props }: any) => {children}, th: ({ children, ...props }: any) => {children}, td: ({ children, ...props }: any) => {children}, @@ -98,6 +99,7 @@ const RbMarkdown: FC = ({ showHtmlComments = false, editable = false, onContentChange, + className }) => { const [editContent, setEditContent] = useState(content) const textareaRef = useRef(null) @@ -162,7 +164,7 @@ const RbMarkdown: FC = ({ // 预览模式 return ( -
+