Compare commits

...

310 Commits

Author SHA1 Message Date
Mark
a7eaf563d7 Merge pull request #352 from SuanmoSuanyangTechnology/fix/model_base
fix(model)
2026-02-06 17:55:03 +08:00
Timebomb2018
4c7809ce4a fix(model): change the "vl" model type of dashscope to "chat" 2026-02-06 17:51:52 +08:00
Timebomb2018
51847955cd fix(model): change the "vl" model type of dashscope to "chat" 2026-02-06 17:48:01 +08:00
yingzhao
52260f469a Merge pull request #342 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): ui update
2026-02-06 13:45:53 +08:00
zhaoying
c566d22836 fix(web): ui update 2026-02-06 13:45:03 +08:00
yingzhao
62c557deae Merge pull request #340 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-02-06 12:32:52 +08:00
zhaoying
677a603835 fix(web): update text 2026-02-06 12:15:49 +08:00
zhaoying
447d8790ad fix(web): ui update 2026-02-06 12:02:21 +08:00
yingzhao
7b3bf41120 Merge pull request #338 from SuanmoSuanyangTechnology/fix/release_web_zy
Revert "feat(web): move prompt menu"
2026-02-06 11:13:35 +08:00
zhaoying
fe3c31c08c Revert "feat(web): move prompt menu"
This reverts commit 9e6e8f50f8.
2026-02-06 11:11:40 +08:00
lixinyue11
4e7ab3d7e3 Fix/release memory bug (#335)
* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

* writer_dup_bug/fix

* writer_graph_bug/fix

* writer_graph_bug/fix

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2026-02-05 17:27:28 +08:00
乐力齐
47b25d7a26 Fix/fact summary (#333)
* [fix]Disable the contents related to fact_summary

* [fix]Disable the contents related to fact_summary

* [fix]Modify the code based on the AI review
2026-02-05 15:56:43 +08:00
lixinyue11
aca7d25001 Fix/release memory bug (#332)
* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

* writer_dup_bug/fix

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2026-02-05 15:22:15 +08:00
yingzhao
dfa7a2d4cf Merge pull request #327 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): markdown table ui update
2026-02-05 13:58:11 +08:00
zhaoying
169e01276d fix(web): markdown table ui update 2026-02-05 13:57:25 +08:00
乐力齐
07e698265e Fix/writer memory bug (#326)
* [fix]Fix the bug

* [fix]Fix the bug

* [fix]Correct the direction indication.
2026-02-05 13:50:04 +08:00
lixinyue11
46ed7e38bf Fix/release memory bug (#324)
* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2026-02-05 12:11:45 +08:00
lixinyue11
3364374dc6 Write Missing None (#321)
* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2026-02-05 10:50:10 +08:00
Mark
f125d11b6d Merge pull request #318 from SuanmoSuanyangTechnology/fix/release_memory_bug
Fix/release memory bug
2026-02-04 20:29:12 +08:00
lixinyue
657d48a5f9 Multiple independent transactions - single transaction 2026-02-04 20:25:45 +08:00
lixinyue
3735bdde19 Multiple independent transactions - single transaction 2026-02-04 20:20:45 +08:00
lixinyue
3f906d81cb Multiple independent transactions - single transaction 2026-02-04 20:19:04 +08:00
lixinyue
7c1f622797 Multiple independent transactions - single transaction 2026-02-04 20:11:05 +08:00
yingzhao
c96df6bfa5 Merge pull request #311 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-02-04 18:38:14 +08:00
zhaoying
9e6e8f50f8 feat(web): move prompt menu 2026-02-04 18:36:45 +08:00
Mark
88b89ef315 Merge pull request #314 from SuanmoSuanyangTechnology/pref/workflow-token
feat(workflow): add token usage statistics for question classifier and parameter extraction
2026-02-04 18:10:22 +08:00
Mark
cc1528f550 Merge pull request #312 from SuanmoSuanyangTechnology/fix/release_memory_bug
Fix/release memory bug
2026-02-04 18:08:12 +08:00
Eternity
1c8a83140b feat(workflow): add token usage statistics for question classifier and parameter extraction 2026-02-04 18:08:02 +08:00
lixinyue
34276e2066 knowledge_retrieval/bug/fix 2026-02-04 18:06:56 +08:00
lixinyue
918e7285c4 knowledge_retrieval/bug/fix 2026-02-04 18:01:05 +08:00
lixinyue
056d422c71 Merge branch 'refs/heads/release/v0.2.3' into fix/release_memory_bug 2026-02-04 18:00:58 +08:00
lixinyue
5ee54f4e0e knowledge_retrieval/bug/fix 2026-02-04 17:57:43 +08:00
zhaoying
260c75e70c fix(web): ui update 2026-02-04 17:47:12 +08:00
Mark
2d7401922f Merge pull request #310 from SuanmoSuanyangTechnology/fix/memoryConfig-ontology
Fix/memory config ontology
2026-02-04 17:46:00 +08:00
zhaoying
8c7a1348cf feat(web): update memory config ontology api 2026-02-04 17:41:53 +08:00
lanceyq
24fbdbd716 [changes]Modify the code based on the AI review 2026-02-04 17:40:19 +08:00
lanceyq
aad8f0e36b [changes]Modify the description of the time for the recent event 2026-02-04 17:23:52 +08:00
yingzhao
15cad44f08 Merge pull request #309 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): replace code editor
2026-02-04 17:22:11 +08:00
zhaoying
0271454671 fix(web): replace code editor 2026-02-04 17:21:04 +08:00
lanceyq
d0ddf288ca [fix]1.The "read_all_config" interface returns "scene_name";2.Memory configuration for lightweight query ontology scenarios 2026-02-04 17:10:35 +08:00
Mark
bc250ac377 Merge pull request #308 from SuanmoSuanyangTechnology/fix/release_memory_bug
Fix/release memory bug
2026-02-04 17:08:20 +08:00
lixinyue
7922fc3b0e knowledge_retrieval/bug/fix 2026-02-04 15:53:13 +08:00
lixinyue
514c19a247 knowledge_retrieval/bug/fix 2026-02-04 15:51:13 +08:00
lixinyue
41550d4a41 knowledge_retrieval/bug/fix 2026-02-04 15:44:26 +08:00
lixinyue
33cc3c1c3f Merge branch 'refs/heads/release/v0.2.3' into fix/release_memory_bug 2026-02-04 15:44:03 +08:00
lixinyue11
8f0a1d9c6e Fix/release memory bug (#306)
* memory_BUG_fix

* memory_BUG

* memory_BUG_long_term

* memory_BUG_long_term

* memory_BUG_long_term
2026-02-04 14:34:00 +08:00
lixinyue
72b5e5cf8e memory_BUG_long_term 2026-02-04 14:24:50 +08:00
lixinyue
62aba2dd38 memory_BUG_long_term 2026-02-04 14:21:49 +08:00
Mark
cdd6b80089 Merge pull request #305 from SuanmoSuanyangTechnology/fix/ontology-v1
Fix/ontology v1
2026-02-04 14:11:57 +08:00
lanceyq
333836f5e7 [changes] 2026-02-04 14:08:09 +08:00
lixinyue
2d28b4b05c memory_BUG_long_term 2026-02-04 13:54:32 +08:00
Mark
48aca996ff Merge pull request #300 from SuanmoSuanyangTechnology/fix/workflow-code
fix(workflow): switch code input encoding to base64+URL encoding
2026-02-04 13:46:00 +08:00
lixinyue
c8c7e9b304 memory_BUG 2026-02-04 13:45:10 +08:00
Mark
97ff023995 Merge pull request #302 from SuanmoSuanyangTechnology/fix/app_statistic
fix(app)
2026-02-04 13:45:09 +08:00
lanceyq
34f0c3b90c [changes]Active status filtering logic, API Key selection strategy 2026-02-04 13:44:07 +08:00
lixinyue
7c2902d2b8 Merge branch 'refs/heads/release/v0.2.3' into fix/release_memory_bug
# Conflicts:
#	api/app/core/memory/agent/langgraph_graph/write_graph.py
2026-02-04 13:43:15 +08:00
yingzhao
8e41afdffc Merge pull request #304 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-02-04 13:39:51 +08:00
zhaoying
7268886294 fix(web): language editor support paste 2026-02-04 13:38:58 +08:00
zhaoying
cbae900866 fix(web): save add session update 2026-02-04 13:37:49 +08:00
lanceyq
ffff138a6f [changes]Attribute security access, secure numerical conversion, unified use of local variables 2026-02-04 13:34:22 +08:00
lanceyq
88c95db8d0 [add]The main project adds multi-API Key load balancing. 2026-02-04 13:34:22 +08:00
Eternity
bc36b79105 fix(workflow): switch code input encoding to base64+URL encoding 2026-02-04 12:28:28 +08:00
Timebomb2018
5694bc0230 fix(fix the key of the app's token): 2026-02-04 12:27:14 +08:00
Mark
ead3080b2b Merge pull request #297 from SuanmoSuanyangTechnology/fix/app_statistic
fix(app)
2026-02-04 12:11:51 +08:00
Timebomb2018
21eae29bb7 feat(app): modify the key of the token 2026-02-04 12:07:59 +08:00
yingzhao
406740b524 Merge pull request #296 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-02-04 12:00:17 +08:00
zhaoying
9d30bc4062 fix(web): space icon required 2026-02-04 11:59:27 +08:00
zhaoying
fad91b64ab fix(web): prompt add disabled 2026-02-04 11:52:34 +08:00
Mark
2132e71a81 Merge pull request #295 from SuanmoSuanyangTechnology/fix/workflow-code
fix(workflow): fix argument passing in code execution nodes
2026-02-04 11:25:55 +08:00
Eternity
24dafa7359 fix(workflow): fix argument passing in code execution nodes 2026-02-04 11:13:28 +08:00
yingzhao
9a3c74fb64 Merge pull request #293 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): PageScrollList style update
2026-02-03 20:03:58 +08:00
zhaoying
f571f0688a fix(web): PageScrollList style update 2026-02-03 19:55:09 +08:00
Ke Sun
3efb3e8a35 fix(memory): add Redis session validation
- Add macOS fork() safety configuration in celery_app.py to prevent initialization issues
- Add null/False checks for Redis session queries in term_memory_save to handle missing sessions gracefully
- Add null/False checks in memory_long_term_storage to prevent processing empty Redis results
- Add null/False checks in aggregate_judgment before format_parsing to avoid errors on missing data
- Initialize redis_messages variable in window_dialogue for consistency
- Add debug logging when no existing session found in Redis for better troubleshooting
- Add TODO comments for magic numbers (scope=6, time=5) to be extracted as constants
- Improve error handling when Redis returns False or empty results instead of crashing
2026-02-03 18:50:59 +08:00
乐力齐
cfcb278406 Ontology v1 bug (#291)
* [changes]Add 'id' as the secondary sorting key, and 'scene_id' now returns a UUID object

* [fix]Fix the "end_user" return to be sorted by update time.

* [fix]Set the default values of the memory configuration model based on the spatial model.

* [fix]Remove the entity extraction check combination model, read the configuration list, and add the return of scene_id

* [fix]Fix the "end_user" return to be sorted by update time.

* [fix]
2026-02-03 18:42:54 +08:00
yingzhao
dc0d34c281 Merge pull request #288 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-02-03 17:24:18 +08:00
zhaoying
72076c218f fix(web): PageScrollList loading update 2026-02-03 17:23:40 +08:00
zhaoying
151fd3b950 fix(web): PageScrollList loading update 2026-02-03 17:22:58 +08:00
Ke Sun
f27de7df35 feat(memory): add long-term storage task routing and batching 2026-02-03 15:52:45 +08:00
Mark
9f2b6390b0 Merge pull request #285 from SuanmoSuanyangTechnology/refactor/workflow-templates
refactor(workflow): relocate template directory into workflow
2026-02-03 15:34:42 +08:00
Eternity
e196f86e30 refactor(workflow): relocate template directory into workflow 2026-02-03 15:24:16 +08:00
yingzhao
ec41d45234 Merge pull request #284 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): remove delete confirm content
2026-02-03 15:21:31 +08:00
zhaoying
567d1ba18b fix(web): remove delete confirm content 2026-02-03 15:20:31 +08:00
yingzhao
5e47fc45ab Merge pull request #280 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): prompt history remove pageLoading
2026-02-03 10:32:02 +08:00
Eternity
b471d56a86 fix(prompt): remove hard-coded import of prompt file paths (#279)
* Fix/develop memory bug (#274)

* 遗漏的历史映射

* 遗漏的历史映射

* fix_timeline_memories

* fix(web): update retrieve_type key

* Fix/develop memory bug (#276)

* 遗漏的历史映射

* 遗漏的历史映射

* fix_timeline_memories

* fix_timeline_memories

* write_gragp/bug_fix

* write_gragp/bug_fix

* write_gragp/bug_fix

* chore(celery): disable periodic task scheduling

* fix(prompt): remove hard-coded import of prompt file paths

---------

Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Ke Sun <kesun5@illinois.edu>
2026-02-03 10:29:51 +08:00
zhaoying
61f8029205 fix(web): prompt history remove pageLoading 2026-02-03 10:27:43 +08:00
lixinyue
1aff4eda67 memory_BUG_fix 2026-02-02 20:31:45 +08:00
Mark
5d5351f0bc Merge pull request #277 from SuanmoSuanyangTechnology/fix/token
feat(app)
2026-02-02 19:06:14 +08:00
Timebomb2018
1224802ac6 feat(app and model): token consumption statistics of the cluster 2026-02-02 19:01:11 +08:00
Mark
3464573f17 Merge pull request #273 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
feat(app and model)
2026-02-02 11:55:20 +08:00
yingzhao
9cf49c9c75 Merge pull request #272 from SuanmoSuanyangTechnology/feature/space_zy
Feature/space zy
2026-02-02 11:50:31 +08:00
lixinyue11
4e837cb90c Add/develop memory (#264)
* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 新增长期记忆功能

* 新增长期记忆功能

* 新增长期记忆功能

* 知识库检索多余字段

* 长期
2026-02-02 11:50:23 +08:00
Timebomb2018
e4fb58496b feat(app and model): token consumption statistics 2026-02-02 11:49:44 +08:00
zhaoying
15a254c0cd feat(web): create space add icon 2026-02-02 11:49:32 +08:00
zhaoying
d62746fc8c feat(web): RadioGroupCard support block mode 2026-02-02 11:49:24 +08:00
yingzhao
4b8b6fe407 Merge pull request #271 from SuanmoSuanyangTechnology/fix/customSelect_zy
fix(web): Restructure the CustomSelect component, repair the interfac…
2026-02-02 11:47:36 +08:00
zhaoying
6754834eb3 fix(web): Restructure the CustomSelect component, repair the interface that is called multiple times when the form is updated 2026-02-02 11:40:19 +08:00
yingzhao
be98db561d Merge pull request #270 from SuanmoSuanyangTechnology/fix/develop_chat_zy
Fix/develop chat zy
2026-02-02 10:41:02 +08:00
zhaoying
574d0afc72 feat(web): show code node 2026-02-02 10:28:21 +08:00
zhaoying
31c8ad611c fix: chat conversation_id add node_start 2026-02-02 10:27:20 +08:00
Mark
b23730388d Merge branch 'release/v0.2.2' into develop 2026-01-31 15:24:20 +08:00
lixinyue11
1b853aa893 隐性+情绪,BUG遗漏 (#267) 2026-01-30 19:09:43 +08:00
Mark
36cb0a12ad [add] migration script 2026-01-30 18:56:44 +08:00
lixinyue11
5439eacf2d Fix/develop memory reflex (#265)
* 遗漏的历史映射

* 遗漏的历史映射

* 反思后台报错处理
2026-01-30 18:46:16 +08:00
乐力齐
2687c3b80e Fix/v022 bug (#263)
* [fix]Fix the issue of inconsistent language in explicit and episodic memory.

* [fix]Fix the issue of inconsistent language in explicit and episodic memory.

* [add]Add scene_id

* [fix]Based on the AI review to fix the code
2026-01-30 18:02:45 +08:00
yingzhao
fa009327ad Merge pull request #262 from SuanmoSuanyangTechnology/feature/ontology_zy
fix(web): conflict resolve
2026-01-30 17:09:46 +08:00
zhaoying
838bd46e83 fix(web): conflict resolve 2026-01-30 17:09:05 +08:00
Ke Sun
ccc2009aa8 feat(celery): add dedicated periodic tasks worker and queue (#261) 2026-01-30 15:31:48 +08:00
Mark
d9aba92314 [add] migration script 2026-01-30 15:27:13 +08:00
乐力齐
696b0475a8 Feature/ontology class clean (#249)
* [add] Complete ontology engineering feature implementation

* [add] Add ontology feature integration and validation utilities

* [add] Add OWL validator and validation utilities

* [fix] Add missing render_ontology_extraction_prompt function

* [fix]Add dependencies, fix functionality
2026-01-30 15:16:39 +08:00
Ke Sun
e7370489e8 Release/v0.2.2 (#260)
* [modify] migration script

* [add] migration script

* fix(web): change form message

* fix(web): the memoryContent field is compatible with numbers and strings

* feat(web): code node hidden

* fix(model):
1. create a basic model to check if the name and provider are duplicated.
2. The result shows error models because the provider created API Keys for all matching models.

---------

Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>
2026-01-30 15:10:22 +08:00
yingzhao
f1503b2238 Merge pull request #256 from SuanmoSuanyangTechnology/feature/ontology_zy
Feature/ontology zy
2026-01-30 14:26:49 +08:00
yingzhao
cd4661e878 Merge branch 'develop' into feature/ontology_zy 2026-01-30 14:26:27 +08:00
Mark
364e01ec7a Merge pull request #255 from SuanmoSuanyangTechnology/fix/model_TimeBomb
fix(model)
2026-01-30 14:26:25 +08:00
Timebomb2018
ffb7b0ba38 fix(model):
1. create a basic model to check if the name and provider are duplicated.
2. The result shows error models because the provider created API Keys for all matching models.
2026-01-30 14:23:35 +08:00
yingzhao
22151eb49b Merge pull request #254 from SuanmoSuanyangTechnology/feature/prompt_zy
Feature/prompt zy
2026-01-30 14:23:00 +08:00
Mark
d0354345f6 Merge pull request #227 from SuanmoSuanyangTechnology/feature/prompt-release
feat(prompt): add history tracking for prompt releases
2026-01-30 14:22:14 +08:00
Mark
b1e61eb1e4 Merge pull request #250 from SuanmoSuanyangTechnology/featrue/sandbox-nodejs
feat(sandbox): add Node.js code execution support to sandbox
2026-01-30 14:21:16 +08:00
Eternity
36e0ed15b6 feat(sandbox): add Node.js code execution support to sandbox 2026-01-30 14:15:42 +08:00
yingzhao
095dfc2879 Merge pull request #253 from SuanmoSuanyangTechnology/fix/codeNode_zy
feat(web): code node hidden
2026-01-30 13:51:06 +08:00
yingzhao
17dea9433e Merge pull request #252 from SuanmoSuanyangTechnology/feature/model_zy
fix(web): change form message
2026-01-30 13:50:45 +08:00
yingzhao
c285444e2f Merge pull request #251 from SuanmoSuanyangTechnology/feature/memoryApi_zy
fix(web): the memoryContent field is compatible with numbers and strings
2026-01-30 13:50:28 +08:00
zhaoying
8ba402d080 feat(web): code node hidden 2026-01-30 13:47:34 +08:00
zhaoying
88ab86734d fix(web): the memoryContent field is compatible with numbers and strings 2026-01-30 12:19:23 +08:00
Ke Sun
504d87b0b0 feat(tasks): add celery task configuration for periodic jobs
- Add ignore_result=True to prevent storing results for periodic tasks
- Set max_retries=0 to skip failed periodic tasks without retry attempts
- Configure acks_late=False for immediate acknowledgment in beat tasks
- Add time_limit and soft_time_limit to regenerate_memory_cache task (3600s/3300s)
- Add time_limit and soft_time_limit to workspace_reflection_task (300s/240s)
- Add time_limit and soft_time_limit to run_forgetting_cycle_task (7200s/7000s)
- Improve task reliability and resource management for scheduled jobs
2026-01-30 12:14:39 +08:00
zhaoying
b0d5818351 fix(web): change form message 2026-01-30 12:08:36 +08:00
Mark
8826a01d32 [add] migration script 2026-01-30 11:17:20 +08:00
Mark
a651ae6ed4 [modify] migration script 2026-01-29 20:15:25 +08:00
lixinyue11
ee50b25d06 Add/develop memory (#247)
* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射
2026-01-29 19:27:02 +08:00
yingzhao
a67be85858 Merge pull request #245 from SuanmoSuanyangTechnology/feature/model_zy
feat(web): model ui update
2026-01-29 19:05:39 +08:00
zhaoying
59c5a3973a feat(web): model ui update 2026-01-29 19:04:57 +08:00
Mark
d76d7343ff Merge pull request #244 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(model)
2026-01-29 18:09:40 +08:00
Timebomb2018
2b9638e7d3 fix(model): bug fix 2026-01-29 18:06:32 +08:00
lixinyue11
3459a73705 Add/develop memory (#243)
* 遗漏的历史映射

* 遗漏的历史映射
2026-01-29 17:57:27 +08:00
yingzhao
bd480a466b Merge pull request #242 from SuanmoSuanyangTechnology/feature/model_zy
feat(web): model ui update
2026-01-29 17:51:41 +08:00
zhaoying
4c34cb55b6 feat(web): model ui update 2026-01-29 17:50:57 +08:00
yingzhao
e137e4a38a Merge pull request #241 from SuanmoSuanyangTechnology/feature/model_zy
feat(web): model ui update
2026-01-29 17:36:41 +08:00
zhaoying
b5989bbc25 feat(web): model ui update 2026-01-29 17:35:54 +08:00
Mark
c31ff7ceef Merge pull request #240 from SuanmoSuanyangTechnology/add/develop_memory
Add/develop memory
2026-01-29 17:28:17 +08:00
zhaoying
9206c7642a feat(web): memory management add scene 2026-01-29 17:13:30 +08:00
zhaoying
d1b4f2b6c2 feat(web): add Ontology menu 2026-01-29 17:13:19 +08:00
lixinyue
75066f2827 遗漏的历史映射 2026-01-29 17:05:49 +08:00
lixinyue
303f3aefef Merge branch 'refs/heads/develop' into add/develop_memory 2026-01-29 16:58:19 +08:00
lixinyue
44fb5e0fd5 遗漏的历史映射 2026-01-29 16:56:50 +08:00
lixinyue11
17a695120a Add/develop memory (#239)
* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射
2026-01-29 16:03:44 +08:00
Mark
6dc716eaf8 Merge pull request #238 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(model)
2026-01-29 16:03:34 +08:00
lixinyue
194be086d4 遗漏的历史映射 2026-01-29 15:58:11 +08:00
zhaoying
cca3900678 feat(web): BodyWrapper compoent update PageLoading 2026-01-29 15:57:59 +08:00
zhaoying
4fe32b7dbc refactor: The PageScrollList component supports two generic parameters 2026-01-29 15:57:52 +08:00
lixinyue
c49603c25b Merge branch 'refs/heads/develop' into add/develop_memory 2026-01-29 15:53:31 +08:00
lixinyue
8de85a4041 遗漏的历史映射 2026-01-29 15:52:32 +08:00
lixinyue
58a2135fa4 遗漏的历史映射 2026-01-29 15:33:37 +08:00
Timebomb2018
ab9a97db22 fix(model): bug fix 2026-01-29 15:25:25 +08:00
Timebomb2018
d291c241d5 fix(model): the model type does not allow modification, delete tts and speech2text type 2026-01-29 15:21:06 +08:00
yingzhao
24d4cb9b94 Merge pull request #237 from SuanmoSuanyangTechnology/feature/model_zy
fix(web): model bugfix
2026-01-29 14:59:05 +08:00
zhaoying
5b9adb799f fix(web): model bugfix 2026-01-29 14:51:27 +08:00
Mark
38b41df36b [fix] api 2026-01-29 14:41:45 +08:00
Mark
34a9befe5c Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop 2026-01-29 14:03:29 +08:00
Mark
67fd579074 [fix] api 2026-01-29 14:03:21 +08:00
Mark
e2714b942d [add]migration script 2026-01-29 13:54:38 +08:00
Mark
6b2556f870 Merge pull request #236 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(model)
2026-01-29 13:51:14 +08:00
Timebomb2018
43e6e9d201 fix(model): bug fix 2026-01-29 12:33:40 +08:00
yingzhao
131e0cc4c7 Merge pull request #235 from SuanmoSuanyangTechnology/feature/model_zy
fix(web): model list remove is_active
2026-01-29 12:18:33 +08:00
zhaoying
537be81b8f fix(web): model list remove is_active 2026-01-29 12:16:45 +08:00
yingzhao
765168db7f Merge pull request #233 from SuanmoSuanyangTechnology/feature/model_zy
fix(web): model bugfix
2026-01-29 12:11:17 +08:00
zhaoying
1e16b06a24 fix(web): model bugfix 2026-01-29 12:10:19 +08:00
Mark
cd4c93a5cb [fix] web search set for v1 api 2026-01-29 11:52:59 +08:00
Mark
808961243d [fix] chat api for workflow 2026-01-29 11:47:39 +08:00
lixinyue11
4d80e119f7 提交遗漏 (#228) 2026-01-29 10:13:55 +08:00
yingzhao
10c87edae1 Merge pull request #230 from SuanmoSuanyangTechnology/feature/model_zy
fix(web): model bugfix
2026-01-28 20:00:25 +08:00
zhaoying
0eb335d112 fix(web): model bugfix 2026-01-28 19:58:33 +08:00
yingzhao
b8b26ccfe5 Merge pull request #229 from SuanmoSuanyangTechnology/feature/model_zy
fix(web): model bugfix
2026-01-28 18:46:27 +08:00
zhaoying
e89c23da4d fix(web): model bugfix 2026-01-28 18:41:56 +08:00
zhaoying
f3da8956d9 feat(web): add prompt menu 2026-01-28 17:50:09 +08:00
Eternity
b1147d77af feat(prompt): add history tracking for prompt releases 2026-01-28 17:47:44 +08:00
zhaoying
66bc2fb41f feat(web): add PageTabs component 2026-01-28 16:41:13 +08:00
zhaoying
4e538a6df8 feat(web): add PageEmpty component 2026-01-28 16:41:04 +08:00
Mark
ced087f8ae Merge pull request #225 from SuanmoSuanyangTechnology/fix/memory_bug_fix
Fix/memory bug fix
2026-01-28 16:10:58 +08:00
lixinyue
0f1eed0b1e 旧数据兼容 2026-01-28 16:07:53 +08:00
lixinyue
95f15b77a3 旧数据兼容 2026-01-28 16:05:54 +08:00
lixinyue
f9ccfd5ca0 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-28 16:05:46 +08:00
lixinyue
7207d7c847 旧数据兼容 2026-01-28 16:05:35 +08:00
lixinyue
00c4a524b7 旧数据兼容 2026-01-28 16:04:38 +08:00
zhaoying
9c3e0b5541 feat(web): add PageTabs component 2026-01-28 16:02:27 +08:00
zhaoying
33bfe33eb3 feat(web): add PageEmpty component 2026-01-28 16:02:18 +08:00
Mark
3127c382a4 Merge pull request #219 from SuanmoSuanyangTechnology/fix/workflow-stream
fix(workflow): fix streaming output issues with multi-output End nodes
2026-01-28 15:32:48 +08:00
Eternity
1748a390ec perf(workflow): make memory write node backward-compatible and defer config validation 2026-01-28 15:30:36 +08:00
Mark
a7c0837049 Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop 2026-01-28 15:25:11 +08:00
Mark
44bf1eeae2 [add] migrations script 2026-01-28 15:24:55 +08:00
yingzhao
762b7a8ef1 Merge pull request #224 from SuanmoSuanyangTechnology/feature/memoryApi_zy
Revert "feat(web): update read_all_config select valueKey"
2026-01-28 15:22:08 +08:00
zhaoying
102712a16e Revert "feat(web): update read_all_config select valueKey"
This reverts commit 46f0f3cee9.
2026-01-28 15:20:31 +08:00
yingzhao
40810c59d7 Merge pull request #223 from SuanmoSuanyangTechnology/fix/agent_zy
fix(web): agent's knowledge_bases bugfix
2026-01-28 15:06:38 +08:00
zhaoying
35a10e86b5 fix(web): agent's knowledge_bases bugfix 2026-01-28 15:05:12 +08:00
yingzhao
c0c985494d Merge pull request #222 from SuanmoSuanyangTechnology/feature/app_statistics_zy
feat(web): add apps statistics api
2026-01-28 14:53:02 +08:00
zhaoying
8984ba7aef feat(web): add apps statistics api 2026-01-28 14:49:30 +08:00
yingzhao
179869d481 Merge pull request #221 from SuanmoSuanyangTechnology/feature/app_statistics_zy
feat(web): add app statistics
2026-01-28 14:47:32 +08:00
yingzhao
5f29956f2b Merge pull request #213 from SuanmoSuanyangTechnology/feature/model_zy
Feature/model zy
2026-01-28 14:46:09 +08:00
Eternity
dbc4ba84c2 fix(workflow): fix streaming output issues with multi-output End nodes
End nodes with multiple output segments could cause cursor errors or leave some
segments inactive, resulting in incorrect final outputs.
Unified _emit_active_chunks and _update_scope_activate to ensure all segments
are activated in order and streamed correctly.
2026-01-28 13:02:50 +08:00
zhaoying
9e4a527675 feat(web): add app statistics 2026-01-28 11:59:37 +08:00
lixinyue
45833542a7 memory_content暂时不修改 2026-01-28 11:57:17 +08:00
lixinyue
1be6de30d7 memory_content暂时不修改 2026-01-28 11:54:07 +08:00
lixinyue
981d78c8ba 统一字段为config_id_old 2026-01-28 11:47:52 +08:00
lixinyue
fbc7bedb6c 统一字段为config_id_old 2026-01-28 11:45:51 +08:00
lixinyue
4786b0c5d4 统一字段为config_id_old 2026-01-28 11:19:24 +08:00
lixinyue
17bed26096 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-28 11:19:09 +08:00
lixinyue
511e16f1d3 统一字段为config_id_old 2026-01-28 11:18:11 +08:00
zhaoying
18204bc1f7 fix(web): model loading update 2026-01-28 11:11:28 +08:00
lixinyue
b58d97fad3 应用层memory_content->memory_config 2026-01-28 10:59:38 +08:00
lixinyue
d2a67a53b5 应用层memory_content->memory_config 2026-01-28 10:58:46 +08:00
lixinyue
c0b556000c Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-28 10:58:06 +08:00
zhaoying
462c3b0696 fix(web): correct spelling 2026-01-28 10:57:45 +08:00
lixinyue
d34ad73439 应用层memory_content->memory_config 2026-01-28 10:56:41 +08:00
zhaoying
2c21712d58 feat(web): model logo update 2026-01-28 10:50:48 +08:00
zhaoying
ce01e588c9 feat(web): remove file url replace 2026-01-28 09:55:20 +08:00
lixinyue
2a23082203 config_id做映射+1 2026-01-27 21:15:38 +08:00
lixinyue
d373f924f6 config_id做映射+1 2026-01-27 21:10:32 +08:00
lixinyue
eaf46ee006 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-27 21:10:22 +08:00
lixinyue
d51355a0ad config_id做映射+1 2026-01-27 21:09:06 +08:00
zhaoying
1e481a311a feat(web): getModelListUrl add is_active param 2026-01-27 20:33:23 +08:00
lixinyue
46abb23ee8 config_id做映射 2026-01-27 20:24:05 +08:00
lixinyue
8555bb697c Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-27 20:23:57 +08:00
lixinyue
f821893653 config_id做映射 2026-01-27 20:22:14 +08:00
zhaoying
75b3ea1f05 feat(web): update model management 2026-01-27 20:07:53 +08:00
zhaoying
74f0018962 feat(web): add PageTabs component 2026-01-27 19:17:32 +08:00
zhaoying
3a0f07d36f feat(web): add PageEmpty component 2026-01-27 19:17:11 +08:00
lixinyue
a047cf2e91 修复宿主列表获取memory_config_idBUG 2026-01-27 14:32:48 +08:00
lixinyue
a8ae16e321 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-27 14:31:28 +08:00
lixinyue
a53be31765 检查需要更改的格式问题 2026-01-27 11:41:16 +08:00
lixinyue
4475be51cc Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-27 10:27:17 +08:00
lixinyue
d53cbe7868 Merge branch 'refs/heads/develop' into fix/memory_bug_fix
# Conflicts:
#	api/app/services/memory_storage_service.py
2026-01-26 17:50:05 +08:00
lixinyue
722746c78b user_id->显示为config_id_old传输 2026-01-26 17:47:05 +08:00
lixinyue
e1f5607836 user_id->显示为config_id_old传输 2026-01-26 17:37:40 +08:00
lixinyue
7cd0d78424 user_id->显示为config_id_old传输 2026-01-26 17:21:10 +08:00
lixinyue
d740559749 Merge branch 'refs/heads/develop' into fix/memory_bug_fix
# Conflicts:
#	api/app/services/memory_storage_service.py
2026-01-26 17:08:55 +08:00
lixinyue
399357f752 user_id->现实为config_id_old 2026-01-26 17:06:55 +08:00
lixinyue
9de6b4f151 感知meta_data字段BUG修复 2026-01-26 11:06:49 +08:00
lixinyue
94cced8323 修复遗留合并BUG 2026-01-23 18:36:33 +08:00
lixinyue
9b8ed16e37 修复遗留合并BUG 2026-01-23 18:35:40 +08:00
lixinyue
a5e44cd229 修复遗留合并BUG 2026-01-23 18:34:13 +08:00
lixinyue
eccc208229 修复遗留合并BUG 2026-01-23 18:34:06 +08:00
lixinyue
79cfabb45d end_user_id清理干净 2026-01-23 17:20:32 +08:00
lixinyue
af6e1e2b99 end_user_id清理干净 2026-01-23 17:20:07 +08:00
lixinyue
4ad51c1b24 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-23 17:15:22 +08:00
lixinyue
c44712167f 解决冲突 2026-01-23 15:03:39 +08:00
lixinyue
1aabaff1f2 解决冲突 2026-01-23 15:00:09 +08:00
lixinyue
21c0383efb Merge branch 'refs/heads/develop' into fix/memory_bug_fix
# Conflicts:
#	api/app/services/memory_agent_service.py
2026-01-23 14:57:25 +08:00
lixinyue
ebe018347b 检查项目,修复group_id的遗留问题 2026-01-23 10:39:10 +08:00
lixinyue
86fe6fe5ab 检查项目,修复group_id的遗留问题 2026-01-23 10:35:41 +08:00
lixinyue
9e828b1750 config_id字段改成UUID,与develop校对恢复 2026-01-22 21:53:15 +08:00
lixinyue
940d3d4567 config_id字段改成UUID 2026-01-22 20:48:51 +08:00
lixinyue
6bd7b2b8bb Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-22 20:47:23 +08:00
lixinyue
f2d6fd7b08 config_id字段改成UUID 2026-01-22 20:40:41 +08:00
lixinyue
b84c82880c config_id字段改成UUID 2026-01-22 18:45:26 +08:00
lixinyue
fcc418b4a0 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-22 18:44:30 +08:00
lixinyue
15c0bb4c9e Merge remote-tracking branch 'origin/develop' into develop 2026-01-22 18:43:53 +08:00
lixinyue
8db4f914d8 config_config替换成memory_config 2026-01-22 18:43:22 +08:00
lixinyue
f3f9211c9c config_config替换成memory_config 2026-01-22 16:59:40 +08:00
lixinyue
a2a69840f7 config_config替换成memory_config 2026-01-22 16:38:24 +08:00
lanceyq
3a4a7590c2 [fix]Fix the memory interface to use end_user_id. 2026-01-22 16:36:12 +08:00
lixinyue
bcc8b7ce3c config_config替换成memory_config 2026-01-22 16:11:48 +08:00
lixinyue
1c7fe6d134 config_config替换成memory_config 2026-01-22 14:59:01 +08:00
lixinyue
c4039f52bd 把group_id替换end_user_id_ 2026-01-22 12:12:41 +08:00
lixinyue
bd851d5e86 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-22 12:11:43 +08:00
lixinyue
00e448c5d6 Merge remote-tracking branch 'origin/develop' into develop 2026-01-22 12:11:17 +08:00
lixinyue
4aeec8afbf 把group_id替换end_user_id_ 2026-01-21 20:37:39 +08:00
lixinyue
f10432bf3f Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-21 20:35:04 +08:00
lixinyue
f0efed8aa1 把group_id替换end_user_id 2026-01-21 20:33:22 +08:00
lixinyue
4a4931bee2 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 2026-01-21 19:37:03 +08:00
lixinyue
afcf12ebc9 Merge remote-tracking branch 'origin/develop' into develop 2026-01-21 19:16:04 +08:00
lixinyue
8f86d3417d Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-21 11:53:52 +08:00
lixinyue
92dfc54c4c Merge remote-tracking branch 'origin/develop' into develop 2026-01-21 11:53:25 +08:00
lixinyue
c93bcb8678 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 2026-01-21 11:27:11 +08:00
lixinyue
98b2da9123 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 2026-01-21 11:15:18 +08:00
lixinyue
cd5f1a1b28 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 2026-01-21 11:05:56 +08:00
lixinyue
0e2e495d09 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 2026-01-21 11:03:37 +08:00
lixinyue
84c6c7e2a6 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 2026-01-21 10:36:04 +08:00
lixinyue
c8ebf9c75a Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-20 20:12:53 +08:00
lixinyue
29852ff0a5 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察) 2026-01-20 20:12:14 +08:00
lixinyue
f06ca62589 Merge branch 'refs/heads/fix/memory_bug_fix' into develop 2026-01-20 20:09:29 +08:00
lixinyue
3f39a2be12 Merge remote-tracking branch 'origin/develop' into develop 2026-01-20 20:09:14 +08:00
lixinyue
575190a96d 读取接口内层嵌套BUG修复 2026-01-20 19:14:32 +08:00
lixinyue
78559d98eb 读取接口内层嵌套BUG修复 2026-01-20 19:11:40 +08:00
lixinyue
398964c747 读取接口内层嵌套BUG修复 2026-01-20 18:51:18 +08:00
lixinyue
a634565296 读取接口内层嵌套BUG修复 2026-01-20 18:46:53 +08:00
lixinyue
a5ecbec9a6 读取接口内层嵌套BUG修复 2026-01-20 16:32:52 +08:00
lixinyue
fe79978f88 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-20 16:32:46 +08:00
lixinyue
978ec8bc75 Merge remote-tracking branch 'origin/develop' into develop
# Conflicts:
#	api/app/services/memory_reflection_service.py
2026-01-20 16:32:27 +08:00
lixinyue
6e77f5b068 反思优化测试接口 2026-01-20 11:11:45 +08:00
lixinyue
c9dbb64269 反思优化测试接口 2026-01-20 11:10:10 +08:00
lixinyue
546d32e3eb Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-20 10:47:32 +08:00
lixinyue
616f6401b4 反思优化1.0(优化隐私输出、时间检索) 2026-01-19 18:06:56 +08:00
lixinyue
d047190453 反思优化1.0(优化隐私输出、时间检索) 2026-01-19 18:06:19 +08:00
lixinyue
17504b1b9c Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-19 18:04:29 +08:00
lixinyue
5a0d3df689 反思优化1.0(优化隐私输出、时间检索) 2026-01-19 16:28:01 +08:00
lixinyue
871304c89b 输出数组 2026-01-15 21:48:08 +08:00
lixinyue
8155150e45 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-15 21:47:48 +08:00
lixinyue
d9fb8edaa9 读取的接口,去掉全局锁 2026-01-15 16:47:55 +08:00
lixinyue
dda61679bd Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-15 16:47:37 +08:00
lixinyue
6ac10a8297 Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-15 10:23:35 +08:00
lixinyue
0695c11739 用户详情优化 2026-01-14 18:25:55 +08:00
lixinyue
7a4297c4f1 Merge branch 'refs/heads/develop' into fix/memory_bug_fix
# Conflicts:
#	api/app/services/user_memory_service.py
2026-01-14 18:25:47 +08:00
lixinyue
2c9e5df27d 用户详情优化 2026-01-14 15:34:45 +08:00
lixinyue
6db37d35ed 用户详情优化 2026-01-14 15:25:04 +08:00
lixinyue
ceee4fe5cf 用户详情优化 2026-01-14 14:54:38 +08:00
lixinyue
130b4a57de Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-14 14:54:33 +08:00
lixinyue
1cee27e830 用户详情优化 2026-01-14 14:51:20 +08:00
lixinyue
ba2ff053f9 用户详情优化 2026-01-14 14:48:37 +08:00
lixinyue
227665439f Merge branch 'refs/heads/develop' into fix/memory_bug_fix 2026-01-14 14:47:15 +08:00
lixinyue
1a2e043ec2 图谱数据量限制数量去掉 2026-01-14 14:27:05 +08:00
lixinyue
89500df0ac 图谱数据量限制数量去掉 2026-01-14 12:20:27 +08:00
lixinyue
cb4e80f1bc 图谱数据量限制数量去掉 2026-01-14 12:15:35 +08:00
264 changed files with 16949 additions and 2808 deletions

View File

@@ -3,9 +3,14 @@ import platform
from datetime import timedelta
from urllib.parse import quote
from app.core.config import settings
from celery import Celery
from app.core.config import settings
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
@@ -63,15 +68,20 @@ celery_app.conf.update(
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → document_tasks queue (prefork worker)
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
},
)
@@ -79,40 +89,40 @@ celery_app.conf.update(
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
# 构建定时任务配置
beat_schedule_config = {
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,
"args": (),
},
"regenerate-memory-cache": {
"task": "app.tasks.regenerate_memory_cache",
"schedule": memory_cache_regeneration_schedule,
"args": (),
},
"run-forgetting-cycle": {
"task": "app.tasks.run_forgetting_cycle_task",
"schedule": forgetting_cycle_schedule,
"kwargs": {
"config_id": None, # 使用默认配置,可以通过环境变量配置
},
},
}
# beat_schedule_config = {
# "run-workspace-reflection": {
# "task": "app.tasks.workspace_reflection_task",
# "schedule": workspace_reflection_schedule,
# "args": (),
# },
# "regenerate-memory-cache": {
# "task": "app.tasks.regenerate_memory_cache",
# "schedule": memory_cache_regeneration_schedule,
# "args": (),
# },
# "run-forgetting-cycle": {
# "task": "app.tasks.run_forgetting_cycle_task",
# "schedule": forgetting_cycle_schedule,
# "kwargs": {
# "config_id": None, # 使用默认配置,可以通过环境变量配置
# },
# },
# }
# 如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule,
"kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
},
}
# if settings.DEFAULT_WORKSPACE_ID:
# beat_schedule_config["write-total-memory"] = {
# "task": "app.controllers.memory_storage_controller.search_all",
# "schedule": memory_increment_schedule,
# "kwargs": {
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
# },
# }
celery_app.conf.beat_schedule = beat_schedule_config
# celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -45,6 +45,7 @@ from . import (
home_page_controller,
memory_perceptual_controller,
memory_working_controller,
ontology_controller,
)
# 创建管理端 API 路由器
@@ -90,5 +91,6 @@ manager_router.include_router(implicit_memory_controller.router)
manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
__all__ = ["manager_router"]

View File

@@ -7,10 +7,11 @@ Routes:
GET /memory/config/emotion - 获取情绪引擎配置
POST /memory/config/emotion - 更新情绪引擎配置
"""
import uuid
from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field
from typing import Optional
from typing import Optional, Union
from sqlalchemy.orm import Session
from uuid import UUID
@@ -21,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
from app.services.emotion_config_service import EmotionConfigService
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -37,7 +39,7 @@ class EmotionConfigQuery(BaseModel):
class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型"""
config_id: UUID = Field(..., description="配置ID")
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -46,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
config_id: UUID = Query(..., description="配置ID"),
config_id: UUID|int = Query(..., description="配置ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -79,7 +81,7 @@ def get_emotion_config(
f"用户 {current_user.username} 请求获取情绪配置",
extra={"config_id": config_id}
)
config_id=resolve_config_id(config_id, db)
# 初始化服务
config_service = EmotionConfigService(db)
@@ -158,6 +160,7 @@ def update_emotion_config(
}
}
"""
config.config_id=resolve_config_id(config.config_id, db)
try:
api_logger.info(
f"用户 {current_user.username} 请求更新情绪配置",

View File

@@ -34,7 +34,7 @@ from app.schemas.memory_storage_schema import (
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -84,7 +84,8 @@ async def trigger_forgetting_cycle(
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id((config_id), db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
@@ -129,7 +130,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
config_id: UUID,
config_id: UUID|int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -158,6 +159,7 @@ async def read_forgetting_config(
)
try:
config_id=resolve_config_id(config_id, db)
# 调用服务层读取配置
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
@@ -195,6 +197,8 @@ async def update_forgetting_config(
ApiResponse: 包含更新结果的响应
"""
workspace_id = current_user.current_workspace_id
payload.config_id=resolve_config_id((payload.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
@@ -255,12 +259,10 @@ async def get_forgetting_stats(
ApiResponse: 包含统计信息的响应
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 end_user_id通过它获取 config_id
config_id = None
if end_user_id:
@@ -269,6 +271,7 @@ async def get_forgetting_stats(
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
@@ -325,7 +328,7 @@ async def get_forgetting_curve(
ApiResponse: 包含遗忘曲线数据的响应
"""
workspace_id = current_user.current_workspace_id
request.config_id = resolve_config_id((request.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")

View File

@@ -25,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
load_dotenv()
api_logger = get_api_logger()
@@ -43,12 +45,12 @@ async def save_reflection_config(
"""Save reflection configuration to data_comfig table"""
try:
config_id = request.config_id
config_id = resolve_config_id(config_id, db)
if not config_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="缺少必需参数: config_id"
)
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
memory_config = MemoryConfigRepository.update_reflection_config(
@@ -99,7 +101,7 @@ async def start_workspace_reflection(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""Activate the reflection function for all matching applications in the workspace"""
"""启动工作空间中所有匹配应用的反思功能"""
workspace_id = current_user.current_workspace_id
reflection_service = MemoryReflectionService(db)
@@ -108,42 +110,55 @@ async def start_workspace_reflection(
service = WorkspaceAppService(db)
result = service.get_workspace_apps_detailed(workspace_id)
reflection_results = []
for data in result['apps_detailed_info']:
if data['memory_configs'] == []:
# 跳过没有配置的应用
if not data['memory_configs']:
api_logger.debug(f"应用 {data['id']} 没有memory_configs跳过")
continue
releases = data['releases']
memory_configs = data['memory_configs']
end_users = data['end_users']
for base, config, user in zip(releases, memory_configs, end_users):
# 安全地转换为整数处理空字符串和None的情况
print(base['config'])
try:
base_config = int(base['config']) if base['config'] else 0
config_id = int(config['config_id']) if config['config_id'] else 0
except (ValueError, TypeError):
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
# 为每个配置和用户组合执行反思
for config in memory_configs:
config_id_str = str(config['config_id'])
# 找到匹配此配置的所有release
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
if not matching_releases:
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
continue
if base_config == config_id and base['app_id'] == user['app_id']:
# 调用反思服务
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
reflection_result = await reflection_service.start_text_reflection(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": base['app_id'],
"config_id": config['config_id'],
"end_user_id": user['id'],
"reflection_result": reflection_result
})
# 为每个用户执行反思
for user in end_users:
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}")
try:
reflection_result = await reflection_service.start_text_reflection(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": reflection_result
})
except Exception as e:
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": {
"status": "错误",
"message": f"反思失败: {str(e)}"
}
})
return success(data=reflection_results, msg="反思配置成功")
@@ -157,17 +172,20 @@ async def start_workspace_reflection(
@router.get("/reflection/configs")
async def start_reflection_configs(
config_id: uuid.UUID,
config_id: uuid.UUID|int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""通过config_id查询memory_config表中的反思配置信息"""
config_id = resolve_config_id(config_id, db)
try:
config_id=resolve_config_id(config_id,db)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
memory_config_id = resolve_config_id(result.config_id, db)
# 构建返回数据
reflection_config = {
"config_id": result.config_id,
"config_id": memory_config_id,
"reflection_enabled": result.enable_self_reflexion,
"reflection_period_in_hours": result.iteration_period,
"reflexion_range": result.reflexion_range,
@@ -192,7 +210,7 @@ async def start_reflection_configs(
@router.get("/reflection/run")
async def reflection_run(
config_id: UUID,
config_id: UUID|int,
language_type: str = Header(default="zh", alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
@@ -200,7 +218,7 @@ async def reflection_run(
"""Activate the reflection function for all matching applications in the workspace"""
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
config_id = resolve_config_id(config_id, db)
# 使用MemoryConfigRepository查询反思配置
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result:

View File

@@ -35,6 +35,8 @@ from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
# Get API logger
api_logger = get_api_logger()
@@ -141,7 +143,6 @@ def create_config(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
@@ -161,12 +162,12 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: UUID,
config_id: UUID|int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
config_id=resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
@@ -188,12 +189,17 @@ def update_config(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(db)
@@ -211,7 +217,7 @@ def update_config_extracted(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
@@ -233,12 +239,12 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
config_id: UUID,
config_id: UUID | int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
@@ -286,6 +292,7 @@ async def pilot_run(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}"
)
payload.config_id = resolve_config_id(payload.config_id, db)
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload),

View File

@@ -7,7 +7,7 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
from app.models.user_model import User
from app.repositories.model_repository import ModelConfigRepository
from app.schemas import model_schema
@@ -31,7 +31,12 @@ def get_model_types():
@router.get("/provider", response_model=ApiResponse)
def get_model_providers():
return success(msg="获取模型提供商成功", data=list(ModelProvider))
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
return success(msg="获取模型提供商成功", data=providers)
@router.get("/strategy", response_model=ApiResponse)
def get_model_strategies():
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
@router.get("", response_model=ApiResponse)
@@ -91,7 +96,7 @@ def get_model_list(
@router.get("/new", response_model=ApiResponse)
def get_model_list(
def get_model_list_new(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
@@ -147,7 +152,7 @@ def get_model_plaza_list(
type: Optional[ModelType] = Query(None, description="模型类型"),
provider: Optional[ModelProvider] = Query(None, description="供应商"),
is_official: Optional[bool] = Query(None, description="是否官方模型"),
is_deprecated: Optional[bool] = Query(False, description="是否弃用"),
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
search: Optional[str] = Query(None, description="搜索关键词"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -198,6 +203,10 @@ def update_model_base(
):
"""更新基础模型"""
# 不允许更改type类型
if data.type is not None or data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
@@ -318,6 +327,8 @@ async def update_composite_model(
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
try:
if model_data.type is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
@@ -457,11 +468,13 @@ async def create_model_api_key_by_provider(
priority=api_key_data.priority,
model_config_ids=model_config_ids
)
created_keys = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
return success(data=result_list, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
except Exception as e:
api_logger.error(f"创建API Key失败: {str(e)}")
raise

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,611 @@
# -*- coding: utf-8 -*-
"""本体场景和类型路由(续)
由于主Controller文件较大将剩余路由放在此文件中。
"""
from uuid import UUID
from typing import Optional
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.ontology_schemas import (
SceneResponse,
SceneListResponse,
PaginationInfo,
ClassCreateRequest,
ClassUpdateRequest,
ClassResponse,
ClassListResponse,
ClassBatchCreateResponse,
)
from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
api_logger = get_api_logger()
def _get_dummy_ontology_service(db: Session) -> OntologyService:
"""获取OntologyService实例不需要LLM
场景和类型管理不需要LLM创建一个dummy配置。
"""
dummy_config = RedBearModelConfig(
model_name="dummy",
provider="openai",
api_key="dummy",
base_url="https://api.openai.com/v1"
)
llm_client = OpenAIClient(model_config=dummy_config)
return OntologyService(llm_client=llm_client, db=db)
# 这些函数将被导入到主Controller中
async def scenes_handler(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
page_size: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
当提供 scene_name 参数时,进行模糊搜索(不分页);
当不提供 scene_name 参数时,返回所有场景(支持分页)。
Args:
workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效
page_size: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if scene_name else "list"
api_logger.info(
f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
)
try:
# 确定工作空间ID
if workspace_id:
try:
ws_uuid = UUID(workspace_id)
except ValueError:
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
else:
ws_uuid = current_user.current_workspace_id
if not ws_uuid:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 根据是否提供 scene_name 决定查询方式
if scene_name and scene_name.strip():
# 验证分页参数(模糊搜索也支持分页)
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页)
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
total = len(scenes)
# 如果提供了分页参数,进行分页处理
if page is not None and page_size is not None:
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
scenes = scenes[start_idx:end_idx]
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
f"in workspace {ws_uuid}, total={total}"
)
else:
# 获取所有场景(支持分页)
# 验证分页参数
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, page_size)
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
# ==================== 本体类型管理接口 ====================
async def create_class_handler(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
# 根据列表长度判断是单个还是批量
count = len(request.classes)
mode = "single" if count == 1 else "batch"
api_logger.info(
f"Class creation ({mode}) requested by user {current_user.id}, "
f"scene_id={request.scene_id}, count={count}"
)
try:
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 准备类型数据
classes_data = [
{
"class_name": item.class_name,
"class_description": item.class_description
}
for item in request.classes
]
if count == 1:
# 单个创建
class_data = classes_data[0]
ontology_class = service.create_class(
scene_id=request.scene_id,
class_name=class_data["class_name"],
class_description=class_data["class_description"],
workspace_id=workspace_id
)
# 构建单个响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
else:
# 批量创建
created_classes, errors = service.create_classes_batch(
scene_id=request.scene_id,
classes=classes_data,
workspace_id=workspace_id
)
# 构建批量响应
items = []
for ontology_class in created_classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassBatchCreateResponse(
total=len(classes_data),
success_count=len(created_classes),
failed_count=len(errors),
items=items,
errors=errors if errors else None
)
api_logger.info(
f"Batch class creation completed: "
f"success={len(created_classes)}, failed={len(errors)}"
)
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e:
api_logger.warning(f"Validation error in class creation: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
async def update_class_handler(
class_id: str,
request: ClassUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新本体类型"""
api_logger.info(
f"Class update requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 更新类型
ontology_class = service.update_class(
class_id=class_uuid,
class_name=request.class_name,
class_description=request.class_description,
workspace_id=workspace_id
)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class updated successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
except ValueError as e:
api_logger.warning(f"Validation error in class update: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
async def delete_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除本体类型"""
api_logger.info(
f"Class deletion requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 删除类型
success_flag = service.delete_class(
class_id=class_uuid,
workspace_id=workspace_id
)
api_logger.info(f"Class deleted successfully: {class_id}")
return success(data={"deleted": success_flag}, msg="类型删除成功")
except ValueError as e:
api_logger.warning(f"Validation error in class deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
async def get_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取单个本体类型"""
api_logger.info(
f"Get class requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取类型会抛出ValueError如果不存在
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class retrieved successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
# 类型不存在或无权限访问
api_logger.warning(f"Validation error in get class: {str(e)}")
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
async def classes_handler(
scene_id: str,
class_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取类型列表(支持模糊搜索和全量查询)
当提供 class_name 参数时,进行模糊搜索;
当不提供 class_name 参数时,返回场景下的所有类型。
Args:
scene_id: 场景ID必填
class_name: 类型名称关键词(可选,支持模糊匹配)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if class_name else "list"
api_logger.info(
f"Class {operation} requested by user {current_user.id}, "
f"keyword={class_name}, scene_id={scene_id}"
)
try:
# 验证UUID格式
try:
scene_uuid = UUID(scene_id)
except ValueError:
api_logger.warning(f"Invalid scene_id format: {scene_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取场景信息
scene = service.get_scene_by_id(scene_uuid, workspace_id)
if not scene:
api_logger.warning(f"Scene not found: {scene_id}")
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
# 根据是否提供 class_name 决定查询方式
if class_name and class_name.strip():
# 模糊搜索类型
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
else:
# 获取所有类型
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
# 构建响应
items = []
for ontology_class in classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassListResponse(
total=len(items),
scene_id=scene_uuid,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
items=items
)
if class_name:
api_logger.info(
f"Class search completed: found {len(items)} classes matching '{class_name}' "
f"in scene {scene_id}"
)
else:
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

View File

@@ -1,5 +1,5 @@
import uuid
import json
import uuid
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.dependencies import get_current_user, get_db
from app.models.prompt_optimizer_model import RoleType
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
from app.schemas.prompt_optimizer_schema import (
PromptOptMessage,
CreateSessionResponse,
SessionHistoryResponse,
SessionMessage,
PromptSaveRequest
)
from app.schemas.response_schema import ApiResponse
from app.services.prompt_optimizer_service import PromptOptimizerService
@@ -135,3 +139,109 @@ async def get_prompt_opt(
"X-Accel-Buffering": "no"
}
)
@router.post(
"/releases",
summary="Get prompt optimization",
response_model=ApiResponse
)
def save_prompt(
data: PromptSaveRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Save a prompt release for the current tenant.
Args:
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
db (Session): SQLAlchemy database session, injected via dependency.
current_user: Currently authenticated user object, injected via dependency.
Returns:
ApiResponse: Standard API response containing the saved prompt release info:
- id: UUID of the prompt release
- session_id: associated session
- title: prompt title
- prompt: prompt content
- created_at: timestamp of creation
Raises:
Any database or service exceptions are propagated to the global exception handler.
"""
service = PromptOptimizerService(db)
prompt_info = service.save_prompt(
tenant_id=current_user.tenant_id,
session_id=data.session_id,
title=data.title,
prompt=data.prompt
)
return success(data=prompt_info)
@router.delete(
"/releases/{prompt_id}",
summary="Delete prompt (soft delete)",
response_model=ApiResponse
)
def delete_prompt(
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Soft delete a prompt release.
Args:
prompt_id
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Success message confirming deletion
"""
service = PromptOptimizerService(db)
service.delete_prompt(
tenant_id=current_user.tenant_id,
prompt_id=prompt_id
)
return success(msg="Prompt deleted successfully")
@router.get(
"/releases/list",
summary="Get paginated list of released prompts with optional filter",
response_model=ApiResponse
)
def get_release_list(
page: int = 1,
page_size: int = 20,
keyword: str | None = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Retrieve paginated list of released prompts for the current tenant.
Optionally filter by keyword in title.
Args:
page (int): Page number (starting from 1)
page_size (int): Number of items per page (max 100)
keyword (str | None): Optional keyword to filter prompt titles
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Contains paginated list of prompt releases with metadata
"""
service = PromptOptimizerService(db)
result = service.get_release_list(
tenant_id=current_user.tenant_id,
page=max(1, page),
page_size=min(max(1, page_size), 100),
filter_keyword=keyword
)
return success(data=result)

View File

@@ -235,11 +235,11 @@ async def chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.id,
@@ -268,11 +268,11 @@ async def chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.id,

View File

@@ -7,27 +7,21 @@ LangChain Agent 封装
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
logger = get_business_logger()
@@ -104,7 +98,7 @@ class LangChainAgent:
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
# "tool_count": len(self.tools)
}
)
@@ -143,100 +137,7 @@ class LangChainAgent:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
messages.append(HumanMessage(content=user_content))
return messages
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_save(self,messages,end_user_end,aimessages):
# '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
# end_user_end=f"Term_{end_user_end}"
# print(messages)
# print(aimessages)
# session_id = store.save_session(
# userid=end_user_end,
# messages=messages,
# apply_id=end_user_end,
# end_user_id=end_user_end,
# aimessages=aimessages
# )
# store.delete_duplicate_sessions()
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# return session_id
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_redis_read(self,end_user_end):
# end_user_end = f"Term_{end_user_end}"
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
# messagss_list=[]
# retrieved_content=[]
# for messages in history:
# query = messages.get("Query")
# aimessages = messages.get("Answer")
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# retrieved_content.append({query: aimessages})
# return messagss_list,retrieved_content
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
if storage_type == "rag":
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
# 2. write_message_task 调用 memory_agent_service.write_memory
# 3. write_memory 调用 write_tools.write传递 messages 参数
# 4. write_tools.write 调用 get_chunked_dialogs传递 messages 参数
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk设置 speaker 字段
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def chat(
self,
@@ -281,30 +182,6 @@ class LangChainAgent:
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# db_for_memory = next(get_db())
# if memory_flag:
# if len(history_term_memory)>=4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# print(retrieved_content)
# # 为长期记忆操作获取新的数据库连接
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# except Exception as e:
# logger.error(f"Failed to write to LongTermMemory: {e}")
# raise
# finally:
# db_for_memory.close()
# # 长期记忆写入(
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -325,17 +202,17 @@ class LangChainAgent:
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
content = msg.content
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
elapsed_time = time.time() - start_time
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, content)
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -343,7 +220,7 @@ class LangChainAgent:
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
"total_tokens": total_tokens
}
}
@@ -403,25 +280,7 @@ class LangChainAgent:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# # TODO 乐力齐
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# if memory_flag:
# if len(history_term_memory) >= 4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# db_for_memory = next(get_db())
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# # 长期记忆写入
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
# except Exception as e:
# logger.error(f"Failed to write to long term memory: {e}")
# finally:
# db_for_memory.close()
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
@@ -437,7 +296,7 @@ class LangChainAgent:
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
full_content=''
full_content = ''
try:
async for event in self.agent.astream_events(
{"messages": messages},
@@ -474,12 +333,17 @@ class LangChainAgent:
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
0) if response_meta else 0
yield total_tokens
break
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, full_content)
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise

View File

@@ -157,6 +157,11 @@ class Settings:
if origin.strip()
]
# Language Configuration
# Supported values: "zh" (Chinese), "en" (English)
# This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

View File

@@ -0,0 +1,238 @@
import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context, get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages使用它替代 structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
根据窗口获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
scope窗口大小
'''
scope=scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages)
# 获取 config_id如果 memory_config 是对象,提取 config_id否则直接使用
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""根据时间"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
聚合判断函数:判断输入句子和历史消息是否描述同一事件
Args:
end_user_id: 终端用户ID
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
json_schema = WriteAggregateModel.model_json_schema()
template_service = TemplateService(template_root)
system_prompt = await template_service.render_template(
template_name='write_aggregate_judgment.jinja2',
operation_name='aggregate_judgment',
history=history,
sentence=ori_messages,
json_schema=json_schema
)
with get_db_context() as db_session:
factory = MemoryClientFactory(db_session)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
messages = [
{
"role": "user",
"content": system_prompt
}
]
structured = await llm_client.response_structured(
messages=messages,
response_model=WriteAggregateModel
)
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
for msg in output_value
]
result_dict = {
"is_same_event": structured.is_same_event,
"output": output_value
}
if not structured.is_same_event:
logger.info(result_dict)
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}

View File

@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
清理后的数据
"""
# 需要过滤的字段列表
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
}
if isinstance(data, dict):

View File

@@ -0,0 +1,72 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
"""
格式化解析消息列表
Args:
messages: 消息列表
type: 返回类型 ('string''dict')
Returns:
格式化后的消息列表
"""
result = []
user=[]
ai=[]
for message in messages:
hstory_messages = message['messages']
for history_messag in hstory_messages.strip().splitlines():
history_messag = json.loads(history_messag)
for content in history_messag:
role = content['role']
content = content['content']
if type == "string":
if role == 'human' or role=="user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict" :
if role == 'human' or role=="user":
user.append( content)
else:
ai.append(content)
if type == "dict":
for key,values in zip(user,ai):
result.append({key:values})
return result
async def messages_parse(messages: list | dict):
user=[]
ai=[]
database=[]
for message in messages:
Query = message['Query']
Query = json.loads(Query)
for data in Query:
role = data['role']
if role == "human":
user.append(data['content'])
if role == "ai":
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
return database
async def agent_chat_messages(user_content,ai_content):
messages = [
{
"role": "user",
"content": f"{user_content}"
},
{
"role": "assistant",
"content": f"{ai_content}"
}
]
return messages

View File

@@ -1,22 +1,20 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langchain_core.messages import HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -34,14 +32,6 @@ async def make_write_graph():
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
@@ -51,43 +41,63 @@ async def make_write_graph():
yield graph
async def main():
"""主函数 - 运行工作流"""
message = "今天周一"
end_user_id = 'new_2025test1103' # 组ID
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
service_name="MemoryAgentService"
)
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j'==node_name:
massages=node_data
massages=massages.get('write_result')['status']
print(massages) # | 更新数据: {node_data}
except Exception as e:
import traceback
traceback.print_exc()
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
if __name__ == "__main__":
import asyncio
asyncio.run(main())
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -0,0 +1,28 @@
"""Pydantic models for write aggregate judgment operations."""
from typing import List, Union
from pydantic import BaseModel, Field
class MessageItem(BaseModel):
"""Individual message item in conversation."""
role: str = Field(..., description="角色user 或 assistant")
content: str = Field(..., description="消息内容")
class WriteAggregateResponse(BaseModel):
"""Response model for aggregate judgment containing judgment result and output."""
is_same_event: bool = Field(
...,
description="是否是同一事件。True表示是同一事件False表示不同事件"
)
output: Union[List[MessageItem], bool] = Field(
...,
description="如果is_same_event为True返回False如果is_same_event为False返回消息列表"
)
# 为了保持向后兼容,保留旧的类名作为别名
WriteAggregateModel = WriteAggregateResponse

View File

@@ -0,0 +1,57 @@
输入句子:{{sentence}}
历史消息:{{history}}
# 你的角色
你是一个擅长事件聚合与语义判断的专家。
# 你的任务
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False
- 描述的是同一个具体事件或事实
- 存在明显的因果关系、前后发展关系
- 是对同一事件的补充、解释、追问或延展
- 逻辑上属于同一语境下的连续讨论
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
- 话题不同,事件主体不同
- 时间、地点、对象明显不同
- 只是语义相似,但并非同一具体事件
- 无直接事件、因果或逻辑关联
# 输出规则(非常重要)
你必须按照以下JSON格式输出
**如果是同一事件:**
```json
{
"is_same_event": true,
"output": false
}
```
**如果不是同一事件:**
```json
{
"is_same_event": false,
"output": [
{
"role": "user",
"content": "输入句子的内容"
},
{
"role": "assistant",
"content": "对应的回复内容"
}
]
}
```
# JSON Schema
{{json_schema}}
# 注意事项
- 必须严格按照上述格式输出
- output 字段:如果是同一事件返回 false如果不是同一事件返回完整的消息列表
- 消息列表必须包含 role 和 content 字段
- 不要输出任何解释、分析或多余内容

View File

@@ -0,0 +1,186 @@
import json
from typing import Any, List, Dict, Optional
from datetime import datetime, timedelta
def serialize_messages(messages: Any) -> str:
"""
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
Args:
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
Returns:
str: JSON 字符串
"""
if isinstance(messages, str):
return messages
if isinstance(messages, (list, tuple)):
# 检查是否是 LangChain 消息对象列表
serialized_list = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# LangChain 消息对象
serialized_list.append({
'type': msg.type,
'content': msg.content,
'role': getattr(msg, 'role', msg.type)
})
elif isinstance(msg, dict):
serialized_list.append(msg)
else:
serialized_list.append(str(msg))
return json.dumps(serialized_list, ensure_ascii=False)
if isinstance(messages, dict):
return json.dumps(messages, ensure_ascii=False)
# 其他类型转为字符串
return str(messages)
def deserialize_messages(messages_str: str) -> Any:
"""
将 JSON 字符串反序列化为原始格式
Args:
messages_str: JSON 字符串
Returns:
反序列化后的对象list、dict 或 string
"""
if not messages_str:
return []
try:
return json.loads(messages_str)
except (json.JSONDecodeError, TypeError):
return messages_str
def fix_encoding(text: str) -> str:
"""
修复错误编码的文本
Args:
text: 需要修复的文本
Returns:
str: 修复后的文本
"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
"""
格式化会话数据为统一的输出格式
Args:
data: 原始会话数据
include_time: 是否包含时间字段
Returns:
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
"""
result = {
"Query": fix_encoding(data.get('messages', '')),
"Answer": fix_encoding(data.get('aimessages', ''))
}
if include_time:
result["starttime"] = data.get('starttime', '')
return result
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
"""
根据时间范围过滤数据
Args:
items: 包含 starttime 字段的数据列表
minutes: 时间范围(分钟)
Returns:
List[Dict]: 过滤后的数据列表
"""
time_threshold = datetime.now() - timedelta(minutes=minutes)
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
filtered_items = []
for item in items:
starttime = item.get('starttime', '')
if starttime and starttime >= time_threshold_str:
filtered_items.append(item)
return filtered_items
def sort_and_limit_results(items: List[Dict], limit: int = 6,
remove_time: bool = True) -> List[Dict]:
"""
对结果进行排序、限制数量并移除时间字段
Args:
items: 数据列表
limit: 最大返回数量
remove_time: 是否移除 starttime 字段
Returns:
List[Dict]: 处理后的数据列表
"""
# 按时间降序排序(最新的在前)
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 限制数量
result_items = items[:limit]
# 移除 starttime 字段
if remove_time:
for item in result_items:
item.pop('starttime', None)
# 如果结果少于1条返回空列表
if len(result_items) < 1:
return []
return result_items
def generate_session_key(session_id: str, key_type: str = "session") -> str:
"""
生成 Redis key
Args:
session_id: 会话ID
key_type: key 类型 ("session", "read", "write", "count")
Returns:
str: Redis key
"""
if key_type == "count":
return f"session:count:{session_id}"
elif key_type == "write":
return f"session:write:{session_id}"
elif key_type == "session" or key_type == "read":
return f"session:{session_id}"
else:
return f"session:{session_id}"
def get_current_timestamp() -> str:
"""
获取当前时间戳字符串
Returns:
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
"""
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -1,11 +1,36 @@
import redis
import uuid
from datetime import datetime
from app.core.config import settings
from typing import List, Dict, Any, Optional, Union
from app.core.memory.agent.utils.redis_base import (
serialize_messages,
deserialize_messages,
fix_encoding,
format_session_data,
filter_by_time_range,
sort_and_limit_results,
generate_session_key,
get_current_timestamp
)
class RedisSessionStore:
class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
@@ -16,32 +41,437 @@ class RedisSessionStore:
)
self.uudi = session_id
def _fix_encoding(self, text):
"""修复错误编码的文本"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
def save_session_write(self, userid: str, messages: str) -> str:
"""
写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒
Args:
userid: 用户ID
messages: 用户消息
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
messages = serialize_messages(messages)
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="write")
# 使用 pipeline 批量写入,减少网络往返
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
"messages": messages,
"starttime": get_current_timestamp()
})
result = pipe.execute()
# 直接写入数据decode_responses=True 已经处理了编码
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}")
raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
"""
通过 save_session_write 的 userid 获取 sessionid 和 messages
Args:
userid: 用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
results.append({
"sessionid": session_id,
"messages": fix_encoding(data.get('messages', ''))
})
if not results:
return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}")
return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
"""
通过 end_user_id 获取所有 write 类型的会话数据
Args:
end_user_id: 终端用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
返回格式:
[
{
"session_id": "uuid",
"id": "...",
"sessionid": "end_user_id",
"messages": "...",
"starttime": "timestamp"
},
...
]
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 end_user_id 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == end_user_id:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
# 构建完整的会话信息
session_info = {
"session_id": session_id,
"id": data.get('id', ''),
"sessionid": data.get('sessionid', ''),
"messages": fix_encoding(data.get('messages', '')),
"starttime": data.get('starttime', '')
}
results.append(session_info)
if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False
# 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
import traceback
traceback.print_exc()
return False
def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]:
"""
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
Args:
userid: 用户ID (对应 sessionid 字段)
minutes: 查询最近几分钟的数据默认5分钟
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
matched_items = []
for data in all_data:
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid and data.get('starttime'):
# write 类型没有 aimessages所以 Answer 为空
matched_items.append({
"Query": fix_encoding(data.get('messages', '')),
"Answer": "",
"starttime": data.get('starttime', '')
})
# 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None)
print(result_items)
elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
def delete_all_write_sessions(self) -> int:
"""
删除所有 write 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:write:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
"""
保存用户访问次数统计
Args:
end_user_id: 终端用户ID
count: 访问次数
messages: 消息内容
Returns:
str: 新生成的 session_id
"""
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"end_user_id": end_user_id,
"count": int(count),
"messages": serialize_messages(messages),
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
"""
通过 end_user_id 查询访问次数统计
Args:
end_user_id: 终端用户ID
Returns:
list 或 False: 如果找到返回 [count, messages],否则返回 False
"""
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
self.r.delete(index_key)
return False
except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
return False
# 直接获取数据
key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key)
if not data:
# 索引存在但数据不存在,清理索引
self.r.delete(index_key)
return False
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
"""
通过 end_user_id 修改访问次数统计(优化版:使用索引)
Args:
end_user_id: 终端用户ID
new_count: 新的 count 值
messages: 消息内容
Returns:
bool: 更新成功返回 True未找到记录返回 False
"""
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# 直接更新数据
key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages)
pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'messages', messages_str)
result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
return False
def delete_all_count_sessions(self) -> int:
"""
删除所有 count 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:count:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
# ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
"""
写入一条会话数据,返回 session_id
Args:
userid: 用户ID
messages: 用户消息
aimessages: AI回复消息
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="read")
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
@@ -49,177 +479,195 @@ class RedisSessionStore:
"end_user_id": end_user_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
"starttime": get_current_timestamp()
})
# 可选设置过期时间例如30天避免数据无限增长
# pipe.expire(key, 30 * 24 * 60 * 60)
# 执行批量操作
result = pipe.execute()
print(f"保存结果: {result[0]}, session_id: {session_id}")
return session_id # 返回新生成的 session_id
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"保存会话失败: {e}")
print(f"[save_session] 保存会话失败: {e}")
raise e
def save_sessions_batch(self, sessions_data):
"""
批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能
"""
try:
session_ids = []
pipe = self.r.pipeline()
for session in sessions_data:
session_id = str(uuid.uuid4())
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}"
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": session.get('userid'),
"apply_id": session.get('apply_id'),
"end_user_id": session.get('end_user_id'),
"messages": session.get('messages'),
"aimessages": session.get('aimessages'),
"starttime": starttime
})
session_ids.append(session_id)
# 一次性执行所有写入操作
results = pipe.execute()
print(f"批量保存完成: {len(session_ids)} 条记录")
return session_ids
except Exception as e:
print(f"批量保存会话失败: {e}")
raise e
# ---------------- 读取 ----------------
def get_session(self, session_id):
# ==================== 读取操作 ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
读取一条会话数据
Args:
session_id: 会话ID
Returns:
Dict 或 None: 会话数据
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
data = self.r.hgetall(key)
return data if data else None
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key in self.r.keys('session:*'):
data = self.r.hgetall(key)
if not data:
continue
# 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
result_items.append(data)
return result_items
def get_all_sessions(self):
"""
获取所有会话数据
获取所有会话数据(不包括 count 和 write 类型)
Returns:
Dict: 所有会话数据key 为 session_id
"""
sessions = {}
for key in self.r.keys('session:*'):
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
# 排除 count 和 write 类型的 key
if ':count:' not in key and ':write:' not in key:
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
return sessions
# ---------------- 更新 ----------------
def update_session(self, session_id, field, value):
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
"""
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
Args:
sessionid: 会话ID支持模糊匹配
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
keys = self.r.keys('session:*')
if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True))
# 排序、限制数量并移除时间字段
result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# ==================== 更新操作 ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool:
"""
更新单个字段
优化版本:使用 pipeline 减少网络往返
Args:
session_id: 会话ID
field: 字段名
value: 字段值
Returns:
bool: 是否更新成功
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
pipe = self.r.pipeline()
pipe.exists(key)
pipe.hset(key, field, value)
results = pipe.execute()
return bool(results[0]) # 返回 key 是否存在
return bool(results[0])
# ---------------- 删除 ----------------
def delete_session(self, session_id):
# ==================== 删除操作 ====================
def delete_session(self, session_id: str) -> int:
"""
删除单条会话
Args:
session_id: 会话ID
Returns:
int: 删除的数量
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
return self.r.delete(key)
def delete_all_sessions(self):
def delete_all_sessions(self) -> int:
"""
删除所有会话
删除所有会话(不包括 count 和 write 类型)
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:*')
if keys:
return self.r.delete(*keys)
# 过滤掉 count 和 write 类型
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
if keys_to_delete:
return self.r.delete(*keys_to_delete)
return 0
def delete_duplicate_sessions(self):
def delete_duplicate_sessions(self) -> int:
"""
删除重复会话数据,条件:
"sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成
删除重复会话数据(不包括 count 和 write 类型)
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
Returns:
int: 删除的数量
"""
import time
start_time = time.time()
# 第一步:使用 pipeline 批量获取所有 key
keys = self.r.keys('session:*')
if not keys:
print("[delete_duplicate_sessions] 没有会话数据")
return 0
# 第二步:使用 pipeline 批量获取所有数据
# 批量获取所有数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 第三步:在内存中识别重复数据
seen = {} # 用字典记录identifier -> key保留第一个出现的 key
keys_to_delete = [] # 需要删除的 key 列表
# 识别重复数据
seen = {}
keys_to_delete = []
for key, data in zip(keys, all_data, strict=False):
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
if not data:
continue
# 获取五个字段的值
sessionid = data.get('sessionid', '')
user_id = data.get('id', '')
end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '')
aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
identifier = (
data.get('sessionid', ''),
data.get('id', ''),
data.get('end_user_id', ''),
data.get('messages', ''),
data.get('aimessages', '')
)
if identifier in seen:
# 重复,标记为待删除
keys_to_delete.append(key)
else:
# 第一次出现,记录
seen[identifier] = key
# 第四步:使用 pipeline 批量删除重复的 key
# 批量删除重复的 key
deleted_count = 0
if keys_to_delete:
# 分批删除,避免单次操作过大
batch_size = 1000
for i in range(0, len(keys_to_delete), batch_size):
batch = keys_to_delete[i:i + batch_size]
@@ -233,79 +681,28 @@ class RedisSessionStore:
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count
def find_user_session(self, sessionid):
user_id = sessionid
result_items = []
for key, values in store.get_all_sessions().items():
history = {}
if user_id == str(values['sessionid']):
history["Query"] = values['messages']
history["Answer"] = values['aimessages']
result_items.append(history)
if len(result_items) <= 1:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
"""
import time
start_time = time.time()
# 使用 pipeline 批量获取数据,提高性能
keys = self.r.keys('session:*')
if not keys:
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 使用 pipeline 批量获取所有 hash 数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 解析并筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
# 检查是否符合三个条件
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({
"Query": self._fix_encoding(data.get('messages')),
"Answer": self._fix_encoding(data.get('aimessages')),
"starttime": data.get('starttime', '')
})
# 按时间降序排序(最新的在前)
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 只保留最新的6条
result_items = matched_items[:6]
# # 移除 starttime 字段
for item in result_items:
item.pop('starttime', None)
# 如果结果少于等于1条返回空列表
if len(result_items) <= 1:
result_items = []
elapsed_time = time.time() - start_time
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# 全局实例
store = RedisSessionStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
)
write_store = RedisWriteStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
count_store = RedisCountStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)

View File

@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally.
"""
import asyncio
import time
from datetime import datetime
@@ -123,23 +124,48 @@ async def write(
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
retry_delay = 1 # 秒
for attempt in range(max_retries):
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
break
else:
logger.warning("Failed to save some data to Neo4j")
if attempt < max_retries - 1:
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
except Exception as e:
error_msg = str(e)
# 检查是否是死锁错误
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
if attempt < max_retries - 1:
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
else:
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
raise
else:
# 非死锁错误,直接抛出
raise
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close()
except Exception as e:
logger.error(f"Error closing Neo4j connector: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)

View File

@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
# Ontology models
from app.core.memory.models.ontology_models import (
OntologyClass,
OntologyExtractionResponse,
)
# Variable configuration models
from app.core.memory.models.variate_config import (
StatementExtractionConfig,
@@ -105,6 +111,9 @@ __all__ = [
"Entity",
"Triplet",
"TripletExtractionResponse",
# Ontology models
"OntologyClass",
"OntologyExtractionResponse",
# Variable configuration
"StatementExtractionConfig",
"ForgettingEngineConfig",

View File

@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity"
)
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -0,0 +1,135 @@
"""Models for ontology classes and extraction responses.
This module contains Pydantic models for representing extracted ontology classes
from scenario descriptions, following OWL ontology engineering standards.
Classes:
OntologyClass: Represents an extracted ontology class
OntologyExtractionResponse: Response model containing extracted ontology classes
"""
from typing import List, Optional
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
class OntologyClass(BaseModel):
"""Represents an extracted ontology class from scenario description.
An ontology class represents an abstract category or concept in a domain,
following OWL ontology engineering standards and naming conventions.
Attributes:
id: Unique string identifier for the ontology class
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
description: Textual description of the class
examples: List of concrete instance examples of this class
parent_class: Optional name of the parent class in the hierarchy
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
id: str = Field(
default_factory=lambda: uuid4().hex,
description="Unique identifier for the ontology class"
)
name: str = Field(
...,
description="Name of the class in PascalCase format"
)
name_chinese: Optional[str] = Field(
None,
description="Chinese translation of the class name"
)
description: str = Field(
...,
description="Description of the class"
)
examples: List[str] = Field(
default_factory=list,
description="List of concrete instance examples"
)
parent_class: Optional[str] = Field(
None,
description="Name of the parent class in the hierarchy"
)
entity_type: str = Field(
...,
description="Type/category of the entity"
)
domain: str = Field(
...,
description="Domain this class belongs to"
)
@field_validator('name')
@classmethod
def validate_pascal_case(cls, v: str) -> str:
"""Validate that the class name follows PascalCase convention.
PascalCase rules:
- Must start with an uppercase letter
- Cannot contain spaces
- Should not contain special characters except underscores
Args:
v: The class name to validate
Returns:
The validated class name
Raises:
ValueError: If the name doesn't follow PascalCase convention
"""
if not v:
raise ValueError("Class name cannot be empty")
if not v[0].isupper():
raise ValueError(
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
)
if ' ' in v:
raise ValueError(
f"Class name '{v}' cannot contain spaces (PascalCase)"
)
# Check for invalid characters (allow alphanumeric and underscore only)
if not all(c.isalnum() or c == '_' for c in v):
raise ValueError(
f"Class name '{v}' contains invalid characters. "
"Only alphanumeric characters and underscores are allowed"
)
return v
class OntologyExtractionResponse(BaseModel):
"""Response model for ontology extraction from LLM.
This model represents the structured output from the LLM when
extracting ontology classes from scenario descriptions.
Attributes:
classes: List of extracted ontology classes
domain: Domain/field the scenario belongs to
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
classes: List[OntologyClass] = Field(
default_factory=list,
description="List of extracted ontology classes"
)
domain: str = Field(
...,
description="Domain/field the scenario belongs to"
)

View File

@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
if len(desc_b) > len(desc_a):
canonical.description = desc_b
# 合并事实摘要:统一保留一个“实体: name”行来源行去重保序
fact_a = getattr(canonical, "fact_summary", "") or ""
fact_b = getattr(ent, "fact_summary", "") or ""
def _extract_sources(txt: str) -> List[str]:
sources: List[str] = []
if not txt:
return sources
for line in str(txt).splitlines():
ln = line.strip()
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = getattr(canonical, "fact_summary", "") or ""
# fact_b = getattr(ent, "fact_summary", "") or ""
# def _extract_sources(txt: str) -> List[str]:
# sources: List[str] = []
# if not txt:
# return sources
# for line in str(txt).splitlines():
# ln = line.strip()
# 支持“来源:”或“来源:”前缀
m = re.match(r"^来源[:]\s*(.+)$", ln)
if m:
content = m.group(1).strip()
if content:
sources.append(content)
# m = re.match(r"^来源[:]\s*(.+)$", ln)
# if m:
# content = m.group(1).strip()
# if content:
# sources.append(content)
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
if not sources and txt.strip():
sources.append(txt.strip())
return sources
# if not sources and txt.strip():
# sources.append(txt.strip())
# return sources
try:
src_a = _extract_sources(fact_a)
src_b = _extract_sources(fact_b)
seen = set()
merged_sources: List[str] = []
for s in src_a + src_b:
if s and s not in seen:
seen.add(s)
merged_sources.append(s)
if merged_sources:
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
elif fact_b and not fact_a:
canonical.fact_summary = fact_b
# src_a = _extract_sources(fact_a)
# src_b = _extract_sources(fact_b)
# seen = set()
# merged_sources: List[str] = []
# for s in src_a + src_b:
# if s and s not in seen:
# seen.add(s)
# merged_sources.append(s)
# if merged_sources:
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
# elif fact_b and not fact_a:
# canonical.fact_summary = fact_b
pass
except Exception:
# 兜底:若解析失败,保留较长文本
if len(fact_b) > len(fact_a):
canonical.fact_summary = fact_b
# if len(fact_b) > len(fact_a):
# canonical.fact_summary = fact_b
pass
except Exception:
pass

View File

@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
desc_a = (getattr(a, "description", "") or "")
desc_b = (getattr(b, "description", "") or "")
fact_a = (getattr(a, "fact_summary", "") or "")
fact_b = (getattr(b, "fact_summary", "") or "")
score_a = len(desc_a) + len(fact_a)
score_b = len(desc_b) + len(fact_b)
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = (getattr(a, "fact_summary", "") or "")
# fact_b = (getattr(b, "fact_summary", "") or "")
# score_a = len(desc_a) + len(fact_a)
# score_b = len(desc_b) + len(fact_b)
score_a = len(desc_a)
score_b = len(desc_b)
if score_a != score_b:
return 0 if score_a >= score_b else 1
return 0
@@ -189,7 +192,8 @@ async def _judge_pair(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -197,7 +201,8 @@ async def _judge_pair(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
# 5. 渲染LLM提示词用工具函数填充模板包含实体信息、上下文、输出格式
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
prompt = render_entity_dedup_prompt(

View File

@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
description=row.get("description") or "",
aliases=row.get("aliases") or [],
name_embedding=row.get("name_embedding") or [],
fact_summary=row.get("fact_summary") or "",
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=row.get("fact_summary") or "",
connect_strength=row.get("connect_strength") or "",
)

View File

@@ -1085,7 +1085,8 @@ class ExtractionOrchestrator:
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
example=getattr(entity, 'example', ''), # 新增:传递示例字段
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),

View File

@@ -8,4 +8,5 @@
- TemporalExtractor: 时间信息提取
- EmbeddingGenerator: 嵌入向量生成
- MemorySummaryGenerator: 记忆摘要生成
- OntologyExtractor: 本体类提取
"""

View File

@@ -14,6 +14,34 @@ from pydantic import Field
logger = get_memory_logger(__name__)
# 支持的语言列表和默认回退值
SUPPORTED_LANGUAGES = {"zh", "en"}
FALLBACK_LANGUAGE = "en"
def validate_language(language: Optional[str]) -> str:
"""
校验语言参数,确保其为有效值。
Args:
language: 待校验的语言代码
Returns:
有效的语言代码("zh""en"
"""
if language is None:
return FALLBACK_LANGUAGE
lang = str(language).lower().strip()
if lang in SUPPORTED_LANGUAGES:
return lang
logger.warning(
f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'"
f"支持的语言: {SUPPORTED_LANGUAGES}"
)
return FALLBACK_LANGUAGE
class MemorySummaryResponse(RobustLLMResponse):
"""Structured response for summary generation per chunk.
@@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse):
async def generate_title_and_type_for_summary(
content: str,
llm_client
llm_client,
language: str = None
) -> Tuple[str, str]:
"""
为MemorySummary生成标题和类型
@@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary(
Args:
content: Summary的内容文本
llm_client: LLM客户端实例
language: 生成标题使用的语言 ("zh" 中文, "en" 英文)如果为None则从配置读取
Returns:
(标题, 类型)元组
"""
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
from app.core.config import settings
# 如果没有指定语言,从配置中读取,并校验有效性
if language is None:
language = settings.DEFAULT_LANGUAGE
language = validate_language(language)
# 定义有效的类型集合
VALID_TYPES = {
@@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary(
}
DEFAULT_TYPE = "conversation" # 默认类型
# 根据语言设置默认标题
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
ERROR_TITLE = "错误" if language == "zh" else "Error"
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
try:
if not content:
logger.warning("content为空无法生成标题和类型")
return ("空内容", DEFAULT_TYPE)
logger.warning(f"content为空无法生成标题和类型 (language={language})")
return (DEFAULT_TITLE, DEFAULT_TYPE)
# 1. 渲染Jinja2提示词模板
prompt = await render_episodic_title_and_type_prompt(content)
# 1. 渲染Jinja2提示词模板,传递语言参数
prompt = await render_episodic_title_and_type_prompt(content, language=language)
# 2. 调用LLM生成标题和类型
messages = [
@@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary(
json_str = json_str.strip()
result_data = json.loads(json_str)
title = result_data.get("title", "未知标题")
title = result_data.get("title", UNKNOWN_TITLE)
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
# 5. 校验和归一化类型
@@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary(
f"已归一化为 '{episodic_type}'"
)
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
return (title, episodic_type)
except json.JSONDecodeError:
logger.error(f"无法解析LLM响应为JSON: {full_response}")
return ("解析失败", DEFAULT_TYPE)
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
except Exception as e:
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
return ("错误", DEFAULT_TYPE)
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
return (ERROR_TITLE, DEFAULT_TYPE)
async def _process_chunk_summary(
dialog: DialogData,
@@ -153,11 +195,16 @@ async def _process_chunk_summary(
return None
try:
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
from app.core.config import settings
language = validate_language(settings.DEFAULT_LANGUAGE)
# Render prompt via Jinja2 for a single chunk
prompt_content = await render_memory_summary_prompt(
chunk_texts=chunk.content,
json_schema=MemorySummaryResponse.model_json_schema(),
max_words=200,
language=language,
)
messages = [
@@ -178,9 +225,10 @@ async def _process_chunk_summary(
try:
title, episodic_type = await generate_title_and_type_for_summary(
content=summary_text,
llm_client=llm_client
llm_client=llm_client,
language=language
)
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
except Exception as e:
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
# Continue without title and type

View File

@@ -0,0 +1,482 @@
"""Ontology class extraction from scenario descriptions using LLM.
This module provides the OntologyExtractor class for extracting ontology classes
from natural language scenario descriptions. It uses LLM-driven extraction combined
with two-layer validation (string validation + OWL semantic validation).
Classes:
OntologyExtractor: Extracts ontology classes from scenario descriptions
"""
import asyncio
import logging
import time
from typing import List, Optional
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.ontology_models import (
OntologyClass,
OntologyExtractionResponse,
)
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
from app.core.memory.utils.validation.owl_validator import OWLValidator
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
logger = logging.getLogger(__name__)
class OntologyExtractor:
"""Extractor for ontology classes from scenario descriptions.
This extractor uses LLM to identify abstract classes and concepts from
natural language scenario descriptions, following OWL ontology engineering
standards. It performs two-layer validation:
1. String validation (naming conventions, reserved words, duplicates)
2. OWL semantic validation (consistency checking, circular inheritance)
Attributes:
llm_client: OpenAI client for LLM calls
validator: String validator for class names and descriptions
owl_validator: OWL validator for semantic validation
"""
def __init__(self, llm_client: OpenAIClient):
"""Initialize the OntologyExtractor.
Args:
llm_client: OpenAIClient instance for LLM processing
"""
self.llm_client = llm_client
self.validator = OntologyValidator()
self.owl_validator = OWLValidator()
logger.info("OntologyExtractor initialized")
async def extract_ontology_classes(
self,
scenario: str,
domain: Optional[str] = None,
max_classes: int = 15,
min_classes: int = 5,
enable_owl_validation: bool = True,
llm_temperature: float = 0.3,
llm_max_tokens: int = 2000,
max_description_length: int = 500,
timeout: Optional[float] = None,
) -> OntologyExtractionResponse:
"""Extract ontology classes from a scenario description.
This is the main extraction method that orchestrates the entire process:
1. Call LLM to extract ontology classes
2. Perform first-layer validation (string validation and cleaning)
3. Perform second-layer validation (OWL semantic validation)
4. Filter invalid classes based on validation errors
5. Return validated ontology classes
Args:
scenario: Natural language scenario description
domain: Optional domain hint (e.g., "Healthcare", "Education")
max_classes: Maximum number of classes to extract (default: 15)
min_classes: Minimum number of classes to extract (default: 5)
enable_owl_validation: Whether to enable OWL validation (default: True)
llm_temperature: LLM temperature parameter (default: 0.3)
llm_max_tokens: LLM max tokens parameter (default: 2000)
max_description_length: Maximum description length (default: 500)
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
Returns:
OntologyExtractionResponse containing validated ontology classes
Raises:
ValueError: If scenario is empty or invalid
asyncio.TimeoutError: If extraction times out
Examples:
>>> extractor = OntologyExtractor(llm_client)
>>> response = await extractor.extract_ontology_classes(
... scenario="A hospital manages patient records...",
... domain="Healthcare",
... max_classes=10,
... timeout=30.0
... )
>>> len(response.classes)
7
"""
# Start timing
start_time = time.time()
# Validate input
if not scenario or not scenario.strip():
logger.error("Scenario description is empty")
raise ValueError("Scenario description cannot be empty")
scenario = scenario.strip()
logger.info(
f"Starting ontology extraction - scenario_length={len(scenario)}, "
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
f"timeout={timeout}"
)
try:
# Step 1: Call LLM for extraction with timeout
logger.info("Step 1: Calling LLM for ontology extraction")
llm_start_time = time.time()
if timeout is not None:
# Wrap LLM call with timeout
try:
response = await asyncio.wait_for(
self._call_llm_for_extraction(
scenario=scenario,
domain=domain,
max_classes=max_classes,
llm_temperature=llm_temperature,
llm_max_tokens=llm_max_tokens,
),
timeout=timeout
)
except asyncio.TimeoutError:
llm_duration = time.time() - llm_start_time
logger.error(
f"LLM extraction timed out after {timeout} seconds "
f"(actual duration: {llm_duration:.2f}s)"
)
# Return empty response on timeout
return OntologyExtractionResponse(
classes=[],
domain=domain or "Unknown",
)
else:
# No timeout specified, call directly
response = await self._call_llm_for_extraction(
scenario=scenario,
domain=domain,
max_classes=max_classes,
llm_temperature=llm_temperature,
llm_max_tokens=llm_max_tokens,
)
llm_duration = time.time() - llm_start_time
logger.info(
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
)
# Step 2: First-layer validation (string validation and cleaning)
logger.info("Step 2: Performing first-layer validation (string validation)")
validation_start_time = time.time()
response = self._validate_and_clean(
response=response,
max_description_length=max_description_length,
)
validation_duration = time.time() - validation_start_time
logger.info(
f"After first-layer validation: {len(response.classes)} classes remain "
f"(validation took {validation_duration:.2f}s)"
)
# Check if we have enough classes after first-layer validation
if len(response.classes) < min_classes:
logger.warning(
f"Only {len(response.classes)} classes remain after validation, "
f"which is below minimum of {min_classes}"
)
# Step 3: Second-layer validation (OWL semantic validation)
if enable_owl_validation and response.classes:
logger.info("Step 3: Performing second-layer validation (OWL validation)")
owl_start_time = time.time()
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
classes=response.classes,
)
owl_duration = time.time() - owl_start_time
if not is_valid:
logger.warning(
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
)
# Filter invalid classes based on errors
response = self._filter_invalid_classes(
response=response,
errors=errors,
)
logger.info(
f"After second-layer validation: {len(response.classes)} classes remain"
)
else:
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
else:
if not enable_owl_validation:
logger.info("Step 3: OWL validation disabled, skipping")
else:
logger.info("Step 3: No classes to validate, skipping OWL validation")
# Calculate total duration
total_duration = time.time() - start_time
# Log extraction statistics
logger.info(
f"Ontology extraction completed - "
f"final_class_count={len(response.classes)}, "
f"domain={response.domain}, "
f"total_duration={total_duration:.2f}s, "
f"llm_duration={llm_duration:.2f}s"
)
return response
except asyncio.TimeoutError:
# Re-raise timeout errors
total_duration = time.time() - start_time
logger.error(
f"Ontology extraction timed out after {timeout} seconds "
f"(total duration: {total_duration:.2f}s)",
exc_info=True
)
raise
except Exception as e:
total_duration = time.time() - start_time
logger.error(
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
exc_info=True
)
# Return empty response on failure
return OntologyExtractionResponse(
classes=[],
domain=domain or "Unknown",
)
async def _call_llm_for_extraction(
self,
scenario: str,
domain: Optional[str],
max_classes: int,
llm_temperature: float,
llm_max_tokens: int,
) -> OntologyExtractionResponse:
"""Call LLM to extract ontology classes from scenario.
This method renders the extraction prompt using the Jinja2 template
and calls the LLM with structured output to get ontology classes.
Args:
scenario: Scenario description text
domain: Optional domain hint
max_classes: Maximum number of classes to extract
llm_temperature: LLM temperature parameter
llm_max_tokens: LLM max tokens parameter
Returns:
OntologyExtractionResponse from LLM
Raises:
Exception: If LLM call fails
"""
try:
# Render prompt using template
prompt_content = await render_ontology_extraction_prompt(
scenario=scenario,
domain=domain,
max_classes=max_classes,
json_schema=OntologyExtractionResponse.model_json_schema(),
)
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
# Create messages for LLM
messages = [
{
"role": "system",
"content": (
"You are an expert ontology engineer specializing in knowledge "
"representation and OWL standards. Extract ontology classes from "
"scenario descriptions following the provided instructions. "
"Return valid JSON conforming to the schema."
),
},
{
"role": "user",
"content": prompt_content,
},
]
# Call LLM with structured output
logger.debug(
f"Calling LLM with temperature={llm_temperature}, "
f"max_tokens={llm_max_tokens}"
)
response = await self.llm_client.response_structured(
messages=messages,
response_model=OntologyExtractionResponse,
)
logger.info(
f"LLM extraction successful - extracted {len(response.classes)} classes"
)
return response
except Exception as e:
logger.error(
f"LLM extraction failed: {str(e)}",
exc_info=True
)
raise
def _validate_and_clean(
self,
response: OntologyExtractionResponse,
max_description_length: int,
) -> OntologyExtractionResponse:
"""Perform first-layer validation: string validation and cleaning.
This method validates and cleans the extracted ontology classes:
1. Validate class names (PascalCase, no reserved words)
2. Sanitize invalid class names
3. Truncate long descriptions
4. Remove duplicate classes
Args:
response: OntologyExtractionResponse from LLM
max_description_length: Maximum description length
Returns:
Cleaned OntologyExtractionResponse
"""
if not response.classes:
logger.debug("No classes to validate")
return response
logger.debug(f"Validating {len(response.classes)} classes")
validated_classes = []
for ontology_class in response.classes:
# Validate class name
is_valid, error_msg = self.validator.validate_class_name(
ontology_class.name
)
if not is_valid:
logger.warning(
f"Invalid class name '{ontology_class.name}': {error_msg}"
)
# Attempt to sanitize
sanitized_name = self.validator.sanitize_class_name(
ontology_class.name
)
logger.info(
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
)
# Update class name
ontology_class.name = sanitized_name
# Re-validate sanitized name
is_valid, error_msg = self.validator.validate_class_name(
sanitized_name
)
if not is_valid:
logger.error(
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
"Skipping this class."
)
continue
# Truncate description if too long
if ontology_class.description:
original_length = len(ontology_class.description)
ontology_class.description = self.validator.truncate_description(
ontology_class.description,
max_length=max_description_length,
)
if len(ontology_class.description) < original_length:
logger.debug(
f"Truncated description for '{ontology_class.name}': "
f"{original_length} -> {len(ontology_class.description)} chars"
)
validated_classes.append(ontology_class)
# Remove duplicates (case-insensitive)
original_count = len(validated_classes)
validated_classes = self.validator.remove_duplicates(validated_classes)
if len(validated_classes) < original_count:
logger.info(
f"Removed {original_count - len(validated_classes)} duplicate classes"
)
# Return cleaned response
return OntologyExtractionResponse(
classes=validated_classes,
domain=response.domain,
)
def _filter_invalid_classes(
self,
response: OntologyExtractionResponse,
errors: List[str],
) -> OntologyExtractionResponse:
"""Filter invalid classes based on OWL validation errors.
This method analyzes OWL validation errors and removes classes
that caused validation failures (e.g., circular inheritance,
inconsistencies).
Args:
response: OntologyExtractionResponse to filter
errors: List of error messages from OWL validation
Returns:
Filtered OntologyExtractionResponse
"""
if not errors:
return response
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
# Extract class names mentioned in errors
invalid_class_names = set()
for error in errors:
# Look for class names in error messages
for ontology_class in response.classes:
if ontology_class.name in error:
invalid_class_names.add(ontology_class.name)
logger.debug(
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
)
# Filter out invalid classes
if invalid_class_names:
original_count = len(response.classes)
filtered_classes = [
c for c in response.classes
if c.name not in invalid_class_names
]
logger.info(
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
f"{invalid_class_names}"
)
return OntologyExtractionResponse(
classes=filtered_classes,
domain=response.domain,
)
return response

View File

@@ -25,6 +25,15 @@ class TripletExtractor:
"""
self.llm_client = llm_client
def _get_language(self) -> str:
"""Get the configured language for entity descriptions
Returns:
Language code ("zh" or "en")
"""
from app.core.config import settings
return settings.DEFAULT_LANGUAGE
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
"""Process a single statement and return extracted triplets and entities"""
# Render the prompt using helper function
@@ -40,7 +49,8 @@ class TripletExtractor:
statement=statement.statement,
chunk_content=chunk_content,
json_schema=TripletExtractionResponse.model_json_schema(),
predicate_instructions=PREDICATE_DEFINITIONS
predicate_instructions=PREDICATE_DEFINITIONS,
language=self._get_language()
)
# Create messages for LLM

View File

@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
key=lambda eid: (
_strength_rank(eid),
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
0 # 临时占位
),
reverse=True
)

View File

@@ -177,7 +177,7 @@ def render_entity_dedup_prompt(
# Args:
# entity_a: Dict of entity A attributes
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
"""
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
@@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
chunk_content: The content of the chunk to process
json_schema: JSON schema for the expected output format
predicate_instructions: Optional predicate instructions
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
@@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
statement=statement,
chunk_content=chunk_content,
json_schema=json_schema,
predicate_instructions=predicate_instructions
predicate_instructions=predicate_instructions,
language=language
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('triplet extraction', rendered_prompt)
@@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
'statement': 'str',
'chunk_content': 'str',
'json_schema': 'TripletExtractionResponse.schema',
'predicate_instructions': 'PREDICATE_DEFINITIONS'
'predicate_instructions': 'PREDICATE_DEFINITIONS',
'language': language
})
return rendered_prompt
@@ -213,6 +216,7 @@ async def render_memory_summary_prompt(
chunk_texts: str,
json_schema: dict,
max_words: int = 200,
language: str = "zh",
) -> str:
"""
Renders the memory summary prompt using the memory_summary.jinja2 template.
@@ -221,6 +225,7 @@ async def render_memory_summary_prompt(
chunk_texts: Concatenated text of conversation chunks
json_schema: JSON schema for the expected output format
max_words: Maximum words for the summary
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string.
@@ -230,12 +235,14 @@ async def render_memory_summary_prompt(
chunk_texts=chunk_texts,
json_schema=json_schema,
max_words=max_words,
language=language,
)
log_prompt_rendering('memory summary', rendered_prompt)
log_template_rendering('memory_summary.jinja2', {
'chunk_texts_len': len(chunk_texts or ""),
'max_words': max_words,
'json_schema': 'MemorySummaryResponse.schema'
'json_schema': 'MemorySummaryResponse.schema',
'language': language
})
return rendered_prompt
@@ -388,24 +395,65 @@ async def render_memory_insight_prompt(
return rendered_prompt
async def render_episodic_title_and_type_prompt(content: str) -> str:
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
"""
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
Args:
content: The content of the episodic memory summary to analyze
language: The language to use for title generation ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("episodic_type_classification.jinja2")
rendered_prompt = template.render(content=content)
rendered_prompt = template.render(content=content, language=language)
# 记录渲染结果到提示日志
log_prompt_rendering('episodic title and type classification', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('episodic_type_classification.jinja2', {
'content_len': len(content) if content else 0
'content_len': len(content) if content else 0,
'language': language
})
return rendered_prompt
async def render_ontology_extraction_prompt(
scenario: str,
domain: str | None = None,
max_classes: int = 15,
json_schema: dict | None = None
) -> str:
"""
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
Args:
scenario: The scenario description text to extract ontology classes from
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
max_classes: Maximum number of classes to extract (default: 15)
json_schema: JSON schema for the expected output format
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("extract_ontology.jinja2")
rendered_prompt = template.render(
scenario=scenario,
domain=domain,
max_classes=max_classes,
json_schema=json_schema
)
# 记录渲染结果到提示日志
log_prompt_rendering('ontology extraction', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('extract_ontology.jinja2', {
'scenario_len': len(scenario) if scenario else 0,
'domain': domain,
'max_classes': max_classes,
'json_schema': 'OntologyExtractionResponse.schema'
})
return rendered_prompt

View File

@@ -9,7 +9,8 @@
- 类型: "{{ entity_a.entity_type | default('') }}"
- 描述: "{{ entity_a.description | default('') }}"
- 别名: {{ entity_a.aliases | default([]) }}
- 摘要: "{{ entity_a.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
实体B:
@@ -17,7 +18,8 @@
- 类型: "{{ entity_b.entity_type | default('') }}"
- 描述: "{{ entity_b.description | default('') }}"
- 别名: {{ entity_b.aliases | default([]) }}
- 摘要: "{{ entity_b.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
上下文:

View File

@@ -1,8 +1,19 @@
=== Task ===
Generate a concise title and classify the episodic memory into the most appropriate category.
{% if language == "zh" %}
**重要:请使用中文生成标题和分类。**
{% else %}
**Important: Please generate the title and classification in English.**
{% endif %}
=== Requirements ===
- Extract a clear, concise title (10-20 characters) that captures the core content
{% if language == "zh" %}
- 标题必须使用中文
{% else %}
- Title must be in English
{% endif %}
- Classify into exactly one category based on the primary theme
- Be specific and avoid ambiguity
- Output must be valid JSON conforming to the schema below

View File

@@ -0,0 +1,210 @@
===Task===
Extract ontology classes from the given scenario description following ontology engineering standards.
===Role===
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
===Scenario Description===
{{ scenario }}
{% if domain -%}
===Domain Hint===
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
{%- endif %}
===Extraction Rules===
**1. Abstract Classes, Not Instances:**
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
- Think in terms of "types of things" rather than "specific things"
**2. Naming Convention (PascalCase):**
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
- Use clear, descriptive names in English
- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA")
- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试")
**3. Domain Relevance:**
- Focus on classes that are central to the scenario's domain
- Prioritize classes that represent key concepts, entities, or relationships
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
**4. Class Quantity:**
- Extract between 5 and {{ max_classes }} classes
- Aim for a balanced set covering the main concepts in the scenario
- Quality over quantity: prefer well-defined classes over exhaustive lists
**5. Clear Descriptions:**
- Provide concise, informative descriptions in Chinese (max 500 characters)
- Describe what the class represents, not specific instances
- Use clear, natural Chinese language that explains the class's role in the domain
**6. Concrete Examples:**
- Provide 2-5 concrete instance examples in Chinese for each class
- Examples should be specific, realistic instances of the class
- Examples help clarify the class's scope and meaning
- Use natural Chinese language for examples
- Example format: ["示例1", "示例2", "示例3"]
**7. Class Hierarchy:**
- Identify parent-child relationships where applicable
- Use the parent_class field to specify inheritance
- Parent class must be one of the extracted classes or a standard OWL class
- Leave parent_class as null for top-level classes
**8. Entity Types:**
- Classify each class with an appropriate entity_type
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
- Choose the most specific type that applies
**9. OWL Reserved Words:**
- Do NOT use OWL reserved words as class names
- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal"
- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class")
**10. Language Consistency:**
- Extract all class names in English (PascalCase format) for the "name" field
- Provide Chinese translation for class names in the "name_chinese" field
- Descriptions MUST be in Chinese (中文)
- Examples MUST be in Chinese (中文)
- Use clear, natural Chinese language for descriptions and examples
===Examples===
**Example 1 (Healthcare Domain):**
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
Output:
{
"classes": [
{
"name": "Patient",
"name_chinese": "患者",
"description": "在医疗机构接受医疗护理或治疗的人",
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
"parent_class": null,
"entity_type": "Person",
"domain": "Healthcare"
},
{
"name": "MedicalProcedure",
"name_chinese": "医疗程序",
"description": "为医疗诊断或治疗而执行的系统性操作流程",
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
},
{
"name": "Diagnosis",
"name_chinese": "诊断",
"description": "基于症状和检查结果对疾病或状况的识别",
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Healthcare"
},
{
"name": "Doctor",
"name_chinese": "医生",
"description": "诊断和治疗患者的持证医疗专业人员",
"examples": ["全科医生", "外科医生", "心脏病专家"],
"parent_class": null,
"entity_type": "Role",
"domain": "Healthcare"
},
{
"name": "Treatment",
"name_chinese": "治疗",
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
}
],
"domain": "Healthcare",
"namespace": "http://example.org/healthcare#"
}
**Example 2 (Education Domain):**
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
Output:
{
"classes": [
{
"name": "Student",
"name_chinese": "学生",
"description": "在教育机构注册学习的人",
"examples": ["本科生", "研究生", "在职学生"],
"parent_class": null,
"entity_type": "Role",
"domain": "Education"
},
{
"name": "Course",
"name_chinese": "课程",
"description": "涵盖特定学科或主题的结构化教育课程",
"examples": ["计算机科学导论", "微积分I", "世界历史"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Education"
},
{
"name": "Professor",
"name_chinese": "教授",
"description": "教授课程并进行研究的学术教师",
"examples": ["助理教授", "副教授", "正教授"],
"parent_class": null,
"entity_type": "Role",
"domain": "Education"
},
{
"name": "AcademicProgram",
"name_chinese": "学术项目",
"description": "通向学位或证书的结构化课程体系",
"examples": ["理学学士", "文学硕士", "博士项目"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Education"
},
{
"name": "Assignment",
"name_chinese": "作业",
"description": "分配给学生以评估学习成果的任务或项目",
"examples": ["论文", "习题集", "研究报告", "实验报告"],
"parent_class": null,
"entity_type": "Object",
"domain": "Education"
},
{
"name": "Lecture",
"name_chinese": "讲座",
"description": "由教师进行的教育性演讲或讲座",
"examples": ["入门讲座", "客座讲座", "在线讲座"],
"parent_class": null,
"entity_type": "Event",
"domain": "Education"
}
],
"domain": "Education",
"namespace": "http://example.org/education#"
}
===Output Format===
**JSON Requirements:**
- Use only ASCII double quotes (") for JSON structure
- Never use Chinese quotation marks ("") or Unicode quotes
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
- All class names must be in PascalCase format
- All class names must be unique (case-insensitive)
- Extract between 5 and {{ max_classes }} classes
{{ json_schema }}

View File

@@ -5,6 +5,12 @@
===Task===
Extract entities and knowledge triplets from the given statement.
{% if language == "zh" %}
**重要请使用中文生成实体描述description和示例example。**
{% else %}
**Important: Please generate entity descriptions and examples in English.**
{% endif %}
===Inputs===
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"
@@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement.
**Entity Extraction:**
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
{% if language == "zh" %}
- **实体描述description必须使用中文**
- **示例example必须使用中文**
{% else %}
- **Entity descriptions must be in English**
- **Examples must be in English**
{% endif %}
- **Semantic Memory Classification (is_explicit_memory):**
* Set to `true` if the entity represents **explicit/semantic memory**:
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
@@ -334,9 +347,11 @@ Output:
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
- The output language should ALWAYS match the input language
- If input is in English, extract statements in English
- If input is in Chinese, extract statements in Chinese
{% if language == "zh" %}
- **语言要求实体描述description和示例example必须使用中文**
{% else %}
- **Language Requirement: Entity descriptions and examples must be in English**
{% endif %}
- Preserve the original language and do not translate
{{ json_schema }}

View File

@@ -5,10 +5,21 @@
=== Task ===
Summarize the provided conversation chunks into a concise Memory summary.
{% if language == "zh" %}
**重要:请使用中文生成摘要内容。**
{% else %}
**Important: Please generate the summary content in English.**
{% endif %}
=== Requirements ===
- Focus on factual statements, user preferences, relationships, and salient temporal context.
- Avoid repetition and filler; be specific.
- Keep it under {{ max_words or 200 }} words.
{% if language == "zh" %}
- 摘要内容必须使用中文
{% else %}
- Summary content must be in English
{% endif %}
- Output must be valid JSON conforming to the schema below.
=== Input ===
@@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary.
4. Do not include line breaks within JSON string values
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
The output language should always be the same as the input language.
{% if language == "zh" %}
**语言要求:输出内容必须使用中文。**
{% else %}
**Language Requirement: The output content must be in English.**
{% endif %}
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
{{ json_schema }}

View File

@@ -0,0 +1,10 @@
"""Validation utilities for ontology extraction.
This module provides validation classes for ontology class names,
descriptions, and OWL compliance checking.
"""
from .ontology_validator import OntologyValidator
from .owl_validator import OWLValidator
__all__ = ['OntologyValidator', 'OWLValidator']

View File

@@ -0,0 +1,268 @@
"""String validation for ontology class names and descriptions.
This module provides the OntologyValidator class for validating and sanitizing
ontology class names according to OWL standards and naming conventions.
Classes:
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
"""
import logging
import re
from typing import List, Tuple
from app.core.memory.models.ontology_models import OntologyClass
logger = logging.getLogger(__name__)
class OntologyValidator:
"""Validator for ontology class names and descriptions.
This validator performs string-level validation including:
- PascalCase naming convention validation
- OWL reserved word checking
- Duplicate class name removal
- Description length truncation
Attributes:
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
"""
# OWL reserved words that cannot be used as class names
OWL_RESERVED_WORDS = {
'Thing', 'Nothing', 'Class', 'Property',
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
'Annotation', 'AnnotationProperty', 'Axiom',
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
'Datatype', 'DataRange', 'Literal',
'DeprecatedClass', 'DeprecatedProperty',
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
'BackwardCompatibleWith', 'OntologyProperty',
}
def validate_class_name(self, name: str) -> Tuple[bool, str]:
"""Validate that a class name follows OWL naming conventions.
Validation rules:
1. Must not be empty
2. Must start with an uppercase letter (PascalCase)
3. Cannot contain spaces
4. Can only contain alphanumeric characters and underscores
5. Cannot be an OWL reserved word
Args:
name: The class name to validate
Returns:
Tuple of (is_valid, error_message)
- is_valid: True if the name is valid, False otherwise
- error_message: Empty string if valid, error description if invalid
Examples:
>>> validator = OntologyValidator()
>>> validator.validate_class_name("MedicalProcedure")
(True, "")
>>> validator.validate_class_name("medical procedure")
(False, "Class name 'medical procedure' cannot contain spaces")
>>> validator.validate_class_name("Thing")
(False, "Class name 'Thing' is an OWL reserved word")
"""
logger.debug(f"Validating class name: '{name}'")
# Check if empty
if not name or not name.strip():
error_msg = "Class name cannot be empty"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
name = name.strip()
# Check if it's an OWL reserved word
if name in self.OWL_RESERVED_WORDS:
error_msg = f"Class name '{name}' is an OWL reserved word"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check if starts with uppercase letter
if not name[0].isupper():
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check for spaces
if ' ' in name:
error_msg = f"Class name '{name}' cannot contain spaces"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check for invalid characters (only alphanumeric and underscore allowed)
if not re.match(r'^[A-Za-z0-9_]+$', name):
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
logger.debug(f"Class name '{name}' is valid")
return True, ""
def sanitize_class_name(self, name: str) -> str:
"""Attempt to sanitize an invalid class name into a valid format.
Sanitization steps:
1. Strip whitespace
2. Remove invalid characters
3. Replace spaces with empty string (PascalCase)
4. Capitalize first letter of each word
5. If result is empty or starts with number, prefix with 'Class'
Args:
name: The class name to sanitize
Returns:
Sanitized class name that should pass validation
Examples:
>>> validator = OntologyValidator()
>>> validator.sanitize_class_name("medical procedure")
'MedicalProcedure'
>>> validator.sanitize_class_name("patient-record")
'PatientRecord'
>>> validator.sanitize_class_name("123invalid")
'Class123Invalid'
"""
logger.debug(f"Sanitizing class name: '{name}'")
if not name or not name.strip():
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
return "UnnamedClass"
# Strip whitespace
name = name.strip()
original_name = name
# Split on spaces, hyphens, and underscores, then capitalize each word
words = re.split(r'[\s\-_]+', name)
# Capitalize first letter of each word and keep rest as is
sanitized_words = []
for word in words:
if word:
# Remove non-alphanumeric characters except underscore
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
if clean_word:
# Capitalize first letter
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
# Join words
sanitized = ''.join(sanitized_words)
# If empty or starts with number, prefix with 'Class'
if not sanitized or sanitized[0].isdigit():
sanitized = 'Class' + sanitized
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
# If it's a reserved word, append 'Class' suffix
if sanitized in self.OWL_RESERVED_WORDS:
sanitized = sanitized + 'Class'
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
return sanitized
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
"""Remove duplicate ontology classes based on case-insensitive name comparison.
When duplicates are found, keeps the first occurrence and discards subsequent ones.
Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
Args:
classes: List of OntologyClass objects
Returns:
List of OntologyClass objects with duplicates removed
Examples:
>>> validator = OntologyValidator()
>>> classes = [
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
... ]
>>> unique = validator.remove_duplicates(classes)
>>> len(unique)
2
>>> [c.name for c in unique]
['Patient', 'Doctor']
"""
if not classes:
logger.debug("No classes to check for duplicates")
return classes
logger.debug(f"Checking {len(classes)} classes for duplicates")
seen_names = set()
unique_classes = []
duplicates_found = []
for ontology_class in classes:
# Use lowercase for comparison
name_lower = ontology_class.name.lower()
if name_lower not in seen_names:
seen_names.add(name_lower)
unique_classes.append(ontology_class)
else:
duplicates_found.append(ontology_class.name)
logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'")
if duplicates_found:
logger.info(
f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}"
)
else:
logger.debug("No duplicate classes found")
return unique_classes
def truncate_description(self, description: str, max_length: int = 500) -> str:
"""Truncate a description to a maximum length.
If the description exceeds max_length, it will be truncated and
an ellipsis (...) will be appended to indicate truncation.
Args:
description: The description text to truncate
max_length: Maximum allowed length (default: 500)
Returns:
Truncated description string
Examples:
>>> validator = OntologyValidator()
>>> long_desc = "A" * 600
>>> truncated = validator.truncate_description(long_desc, max_length=500)
>>> len(truncated)
500
>>> truncated.endswith("...")
True
"""
if not description:
return ""
if len(description) <= max_length:
return description
# Truncate and add ellipsis
# Reserve 3 characters for "..."
truncate_at = max_length - 3
truncated = description[:truncate_at] + "..."
logger.debug(
f"Truncated description from {len(description)} to {len(truncated)} characters"
)
return truncated

View File

@@ -0,0 +1,585 @@
"""OWL semantic validation for ontology classes using Owlready2.
This module provides the OWLValidator class for validating ontology classes
against OWL standards using the Owlready2 library. It performs semantic
validation including consistency checking, circular inheritance detection,
and OWL file export.
Classes:
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
"""
import logging
from typing import List, Optional, Tuple
from owlready2 import (
World,
Thing,
get_ontology,
sync_reasoner_pellet,
OwlReadyInconsistentOntologyError,
)
from app.core.memory.models.ontology_models import OntologyClass
logger = logging.getLogger(__name__)
class OWLValidator:
"""Validator for OWL semantic validation of ontology classes.
This validator performs semantic-level validation using Owlready2 including:
- Creating OWL classes from ontology class definitions
- Running consistency checking with Pellet reasoner
- Detecting circular inheritance
- Validating Protégé compatibility
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
Attributes:
base_namespace: Base URI for the ontology namespace
"""
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
"""Initialize the OWL validator.
Args:
base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#)
"""
self.base_namespace = base_namespace
def validate_ontology_classes(
self,
classes: List[OntologyClass],
) -> Tuple[bool, List[str], Optional[World]]:
"""Validate extracted ontology classes against OWL standards.
This method creates an OWL ontology from the provided classes using Owlready2,
runs consistency checking with the Pellet reasoner, and detects common issues
like circular inheritance.
Args:
classes: List of OntologyClass objects to validate
Returns:
Tuple of (is_valid, error_messages, world):
- is_valid: True if ontology is valid and consistent, False otherwise
- error_messages: List of error/warning messages
- world: Owlready2 World object containing the ontology (None if validation failed)
Examples:
>>> validator = OWLValidator()
>>> classes = [
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
... ]
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
>>> is_valid
True
>>> len(errors)
0
"""
if not classes:
return False, ["No classes provided for validation"], None
errors = []
try:
# Create a new world (isolated ontology environment)
world = World()
# Use a proper ontology IRI
# Owlready2 expects the IRI to end with .owl or similar
onto_iri = self.base_namespace.rstrip('#/')
if not onto_iri.endswith('.owl'):
onto_iri = onto_iri + '.owl'
# Create ontology
onto = world.get_ontology(onto_iri)
with onto:
# Dictionary to store created OWL classes for parent reference
owl_classes = {}
# First pass: Create all classes without parent relationships
for ontology_class in classes:
try:
# Create OWL class dynamically using type() with Thing as base
# The key is to NOT set namespace in the dict, let Owlready2 handle it
owl_class = type(
ontology_class.name, # Class name
(Thing,), # Base classes
{} # Class dict (empty, let Owlready2 manage)
)
# Add label (rdfs:label) - include both English and Chinese names
labels = [ontology_class.name]
if ontology_class.name_chinese:
labels.append(ontology_class.name_chinese)
owl_class.label = labels
# Add comment (rdfs:comment) with description
if ontology_class.description:
owl_class.comment = [ontology_class.description]
# Store for parent relationship setup
owl_classes[ontology_class.name] = owl_class
logger.debug(
f"Created OWL class: {ontology_class.name} "
f"(Chinese: {ontology_class.name_chinese}) "
f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
)
except Exception as e:
error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
errors.append(error_msg)
logger.error(error_msg, exc_info=True)
# Second pass: Set up parent relationships
for ontology_class in classes:
if ontology_class.parent_class and ontology_class.name in owl_classes:
parent_name = ontology_class.parent_class
# Check if parent exists
if parent_name in owl_classes:
try:
child_class = owl_classes[ontology_class.name]
parent_class = owl_classes[parent_name]
# Set parent by modifying is_a
child_class.is_a = [parent_class]
logger.debug(
f"Set parent relationship: {ontology_class.name} -> {parent_name}"
)
except Exception as e:
error_msg = (
f"Failed to set parent relationship "
f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
)
errors.append(error_msg)
logger.warning(error_msg)
else:
warning_msg = (
f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
)
errors.append(warning_msg)
logger.warning(warning_msg)
# Check for circular inheritance
for class_name, owl_class in owl_classes.items():
if self._has_circular_inheritance(owl_class):
error_msg = f"Circular inheritance detected for class '{class_name}'"
errors.append(error_msg)
logger.error(error_msg)
# Run consistency checking with Pellet reasoner
try:
logger.info("Running Pellet reasoner for consistency checking...")
sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
logger.info("Consistency check passed")
except OwlReadyInconsistentOntologyError as e:
error_msg = f"Ontology is inconsistent: {str(e)}"
errors.append(error_msg)
logger.error(error_msg)
return False, errors, world
except Exception as e:
# Reasoner errors are often due to Java not being installed or configured
# Log as warning but don't fail validation - ontology structure is still valid
warning_msg = f"Reasoner check skipped: {str(e)}"
if str(e).strip(): # Only log if there's an actual error message
logger.warning(warning_msg)
else:
logger.warning("Reasoner check skipped: Java may not be installed or configured")
# Continue - ontology structure is valid even without reasoner check
# If we have errors (excluding warnings), validation failed
is_valid = len(errors) == 0
return is_valid, errors, world
except Exception as e:
error_msg = f"OWL validation failed: {str(e)}"
errors.append(error_msg)
logger.error(error_msg, exc_info=True)
return False, errors, None
def _has_circular_inheritance(self, owl_class) -> bool:
"""Check if an OWL class has circular inheritance.
Circular inheritance occurs when a class inherits from itself through
a chain of parent relationships (e.g., A -> B -> C -> A).
Args:
owl_class: Owlready2 class object to check
Returns:
True if circular inheritance is detected, False otherwise
"""
visited = set()
current = owl_class
while current:
# Get class IRI or name as identifier
class_id = str(current.iri) if hasattr(current, 'iri') else str(current)
if class_id in visited:
# Found a cycle
return True
visited.add(class_id)
# Get parent classes (is_a relationship)
parents = getattr(current, 'is_a', [])
# Filter out Thing and other base classes
parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]
if not parent_classes:
# No more parents, no cycle
break
# Check first parent (in single inheritance)
current = parent_classes[0] if parent_classes else None
return False
def export_to_owl(
self,
world: World,
output_path: Optional[str] = None,
format: str = "rdfxml",
classes: Optional[List] = None
) -> str:
"""Export ontology to OWL file in specified format.
Supported formats:
- rdfxml: RDF/XML format (default, most compatible)
- turtle: Turtle format (more readable)
- ntriples: N-Triples format (simplest)
- json: JSON format (simplified, human-readable)
Args:
world: Owlready2 World object containing the ontology
output_path: Optional file path to save the ontology (if None, returns string)
format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
classes: Optional list of OntologyClass objects (required for json format)
Returns:
String representation of the exported ontology
Raises:
ValueError: If format is not supported
RuntimeError: If export fails
Examples:
>>> validator = OWLValidator()
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
>>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
"""
# Validate format
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
if format not in valid_formats:
raise ValueError(
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
)
# JSON format doesn't need OWL processing
if format == "json":
if not classes:
raise ValueError("Classes list is required for JSON format export")
return self._export_to_json(classes)
# For OWL formats, world is required
if not world:
raise ValueError("World object is None. Cannot export ontology.")
# Note: Owlready2 has issues with turtle format export
# We'll handle it specially by converting from rdfxml
use_conversion = (format == "turtle")
try:
# Get all ontologies in the world
ontologies = list(world.ontologies.values())
if not ontologies:
raise RuntimeError("No ontologies found in world")
# Find the ontology with classes (skip anonymous/empty ontologies)
onto = None
for ont in ontologies:
classes_count = len(list(ont.classes()))
logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
if classes_count > 0:
onto = ont
break
# If no ontology with classes found, use the last non-anonymous one
if onto is None:
for ont in reversed(ontologies):
if ont.base_iri != "http://anonymous/":
onto = ont
break
# If still no ontology, use the first one
if onto is None:
onto = ontologies[0]
# Log ontology contents for debugging
logger.info(f"Ontology IRI: {onto.base_iri}")
logger.info(f"Ontology contains {len(list(onto.classes()))} classes")
# List all classes in the ontology
all_classes = list(onto.classes())
for cls in all_classes:
logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})")
if hasattr(cls, 'label'):
logger.debug(f" Labels: {cls.label}")
if hasattr(cls, 'comment'):
logger.debug(f" Comments: {cls.comment}")
if len(all_classes) == 0:
logger.warning("No classes found in ontology! This may indicate a problem with class creation.")
if output_path:
# Save to file
export_format = "rdfxml" if use_conversion else format
logger.info(f"Exporting ontology to {output_path} in {export_format} format")
onto.save(file=output_path, format=export_format)
# Read back the file content to return
with open(output_path, 'r', encoding='utf-8') as f:
content = f.read()
# Convert to turtle if needed
if use_conversion:
content = self._convert_to_turtle(content)
logger.info(f"Successfully exported ontology to {output_path}")
# Format the content for better readability
content = self._format_owl_content(content, format)
return content
else:
# Export to string (save to temporary location and read)
import tempfile
import os
with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
tmp_path = tmp.name
try:
export_format = "rdfxml" if use_conversion else format
onto.save(file=tmp_path, format=export_format)
with open(tmp_path, 'r', encoding='utf-8') as f:
content = f.read()
# Convert to turtle if needed
if use_conversion:
content = self._convert_to_turtle(content)
# Format the content for better readability
content = self._format_owl_content(content, format)
return content
finally:
# Clean up temporary file
if os.path.exists(tmp_path):
os.remove(tmp_path)
except Exception as e:
error_msg = f"Failed to export ontology: {str(e)}"
logger.error(error_msg, exc_info=True)
raise RuntimeError(error_msg) from e
def _export_to_json(self, classes: List) -> str:
"""Export ontology classes to simplified JSON format.
This format is more compact and easier to parse than OWL XML.
Args:
classes: List of OntologyClass objects
Returns:
JSON string representation (compact format)
"""
import json
result = {
"ontology": {
"namespace": self.base_namespace,
"classes": []
}
}
for cls in classes:
class_data = {
"name": cls.name,
"name_chinese": cls.name_chinese,
"description": cls.description,
"entity_type": cls.entity_type,
"domain": cls.domain,
"parent_class": cls.parent_class,
"examples": cls.examples if hasattr(cls, 'examples') else []
}
result["ontology"]["classes"].append(class_data)
# 使用紧凑格式:无缩进,使用分隔符减少空格
return json.dumps(result, ensure_ascii=False, separators=(',', ':'))
def _convert_to_turtle(self, rdfxml_content: str) -> str:
"""Convert RDF/XML content to Turtle format using rdflib.
Args:
rdfxml_content: RDF/XML format content
Returns:
Turtle format content
"""
try:
from rdflib import Graph
# Parse RDF/XML
g = Graph()
g.parse(data=rdfxml_content, format="xml")
# Serialize to Turtle
turtle_content = g.serialize(format="turtle")
# Handle bytes vs string
if isinstance(turtle_content, bytes):
turtle_content = turtle_content.decode('utf-8')
return turtle_content
except ImportError:
logger.warning(
"rdflib is not installed. Cannot convert to Turtle format. "
"Install with: pip install rdflib"
)
return rdfxml_content
except Exception as e:
logger.error(f"Failed to convert to Turtle format: {e}")
return rdfxml_content
def _format_owl_content(self, content: str, format: str) -> str:
"""Format OWL content for better readability.
Args:
content: Raw OWL content string
format: Format type (rdfxml, turtle, ntriples)
Returns:
Formatted OWL content string
"""
if format == "rdfxml":
# Format XML with proper indentation
try:
import xml.dom.minidom as minidom
dom = minidom.parseString(content)
# Pretty print with 2-space indentation
formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8")
# Remove extra blank lines
lines = []
prev_blank = False
for line in formatted.split('\n'):
is_blank = not line.strip()
if not (is_blank and prev_blank): # Skip consecutive blank lines
lines.append(line)
prev_blank = is_blank
formatted = '\n'.join(lines)
return formatted
except Exception as e:
logger.warning(f"Failed to format XML content: {e}")
return content
elif format == "turtle":
# Turtle format is already relatively readable
# Just ensure consistent line endings and not empty
if not content or content.strip() == "":
logger.warning("Turtle content is empty, this may indicate an export issue")
return content.strip() + '\n' if content.strip() else content
elif format == "ntriples":
# N-Triples format is line-based, ensure proper line endings
return content.strip() + '\n' if content.strip() else content
return content
def validate_with_protege_compatibility(
self,
classes: List[OntologyClass]
) -> Tuple[bool, List[str]]:
"""Validate that ontology classes are compatible with Protégé editor.
Protégé compatibility checks:
- Class names are valid OWL identifiers
- No special characters that Protégé cannot handle
- Namespace is properly formatted
- Labels and comments are properly encoded
Args:
classes: List of OntologyClass objects to validate
Returns:
Tuple of (is_compatible, warnings):
- is_compatible: True if compatible with Protégé, False otherwise
- warnings: List of compatibility warning messages
Examples:
>>> validator = OWLValidator()
>>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
>>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
>>> is_compatible
True
"""
warnings = []
# Check namespace format
if not self.base_namespace.startswith(('http://', 'https://')):
warnings.append(
f"Namespace '{self.base_namespace}' should start with http:// or https:// "
"for Protégé compatibility"
)
if not self.base_namespace.endswith(('#', '/')):
warnings.append(
f"Namespace '{self.base_namespace}' should end with # or / "
"for Protégé compatibility"
)
# Check each class
for ontology_class in classes:
# Check for special characters that might cause issues
if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']):
warnings.append(
f"Class name '{ontology_class.name}' contains special characters "
"that may cause issues in Protégé"
)
# Check description length (Protégé can handle long descriptions but may display poorly)
if ontology_class.description and len(ontology_class.description) > 1000:
warnings.append(
f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) "
"which may display poorly in Protégé"
)
# Check for non-ASCII characters (Protégé supports them but encoding issues may occur)
if not ontology_class.name.isascii():
warnings.append(
f"Class name '{ontology_class.name}' contains non-ASCII characters "
"which may cause encoding issues in some Protégé versions"
)
# If no warnings, it's compatible
is_compatible = len(warnings) == 0
return is_compatible, warnings

View File

@@ -0,0 +1 @@
"""模型配置脚本模块"""

View File

@@ -0,0 +1,174 @@
provider: bedrock
enabled: true
models:
- name: ai21
type: llm
provider: bedrock
description: AI21 Labs大语言模型completion生成模式256000上下文窗口
is_deprecated: false
is_official: true
tags:
- 大语言模型
logo: bedrock
- name: amazon nova
type: llm
provider: bedrock
description: Amazon Nova大语言模型支持智能体思考、工具调用、流式工具调用、视觉能力300000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- stream-tool-call
- vision
logo: bedrock
- name: anthropic claude
type: llm
provider: bedrock
description: Anthropic Claude大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
logo: bedrock
- name: cohere
type: llm
provider: bedrock
description: Cohere大语言模型支持智能体思考、工具调用、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- stream-tool-call
logo: bedrock
- name: deepseek
type: llm
provider: bedrock
description: DeepSeek大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- vision
- tool-call
- stream-tool-call
logo: bedrock
- name: meta
type: llm
provider: bedrock
description: Meta Llama大语言模型支持智能体思考、工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
logo: bedrock
- name: mistral
type: llm
provider: bedrock
description: Mistral AI大语言模型支持智能体思考、工具调用32000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
logo: bedrock
- name: openai
type: llm
provider: bedrock
description: OpenAI大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- stream-tool-call
logo: bedrock
- name: qwen
type: llm
provider: bedrock
description: Qwen大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- stream-tool-call
logo: bedrock
- name: amazon.rerank-v1:0
type: rerank
provider: bedrock
description: amazon.rerank-v1:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
tags:
- 重排序模型
logo: bedrock
- name: cohere.rerank-v3-5:0
type: rerank
provider: bedrock
description: cohere.rerank-v3-5:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
tags:
- 重排序模型
logo: bedrock
- name: amazon.nova-2-multimodal-embeddings-v1:0
type: embedding
provider: bedrock
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型支持视觉能力8192上下文窗口
is_deprecated: false
is_official: true
tags:
- 文本嵌入模型
- vision
logo: bedrock
- name: amazon.titan-embed-text-v1
type: embedding
provider: bedrock
description: amazon.titan-embed-text-v1文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
tags:
- 文本嵌入模型
logo: bedrock
- name: amazon.titan-embed-text-v2:0
type: embedding
provider: bedrock
description: amazon.titan-embed-text-v2:0文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
tags:
- 文本嵌入模型
logo: bedrock
- name: cohere.embed-english-v3
type: embedding
provider: bedrock
description: Cohere Embed 3 English文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
tags:
- 文本嵌入模型
logo: bedrock
- name: cohere.embed-multilingual-v3
type: embedding
provider: bedrock
description: Cohere Embed 3 Multilingual文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
tags:
- 文本嵌入模型
logo: bedrock

View File

@@ -0,0 +1,820 @@
provider: dashscope
enabled: true
models:
- name: deepseek-r1-distill-qwen-14b
type: llm
provider: dashscope
description: DeepSeek-R1-Distill-Qwen-14B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-r1-distill-qwen-32b
type: llm
provider: dashscope
description: DeepSeek-R1-Distill-Qwen-32B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-r1
type: llm
provider: dashscope
description: DeepSeek-R1大语言模型支持智能体思考131072超大上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.1
type: llm
provider: dashscope
description: DeepSeek-V3.1大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.2-exp
type: llm
provider: dashscope
description: DeepSeek-V3.2-exp实验版大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3.2
type: llm
provider: dashscope
description: DeepSeek-V3.2大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: deepseek-v3
type: llm
provider: dashscope
description: DeepSeek-V3大语言模型支持智能体思考64000上下文窗口对话模式支持文本与JSON格式输出
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
logo: dashscope
- name: farui-plus
type: llm
provider: dashscope
description: farui-plus大语言模型支持多工具调用、智能体思考、流式工具调用12288上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: glm-4.7
type: llm
provider: dashscope
description: GLM-4.7大语言模型支持多工具调用、智能体思考、流式工具调用202752超大上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qvq-max-latest
type: llm
provider: dashscope
description: qvq-max-latest大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- vision
- agent-thought
- stream-tool-call
logo: dashscope
- name: qvq-max
type: llm
provider: dashscope
description: qvq-max大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- vision
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-coder-turbo-0919
type: llm
provider: dashscope
description: qwen-coder-turbo-0919代码专用大语言模型支持智能体思考131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen-max-latest
type: llm
provider: dashscope
description: qwen-max-latest大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-max-longcontext
type: llm
provider: dashscope
description: qwen-max-longcontext长上下文大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-max
type: llm
provider: dashscope
description: qwen-max大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-mt-plus
type: llm
provider: dashscope
description: qwen-mt-plus多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 翻译模型
- agent-thought
logo: dashscope
- name: qwen-mt-turbo
type: llm
provider: dashscope
description: qwen-mt-turbo轻量化多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 翻译模型
- agent-thought
logo: dashscope
- name: qwen-plus-0112
type: llm
provider: dashscope
description: qwen-plus-0112大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0125
type: llm
provider: dashscope
description: qwen-plus-0125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0723
type: llm
provider: dashscope
description: qwen-plus-0723大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0806
type: llm
provider: dashscope
description: qwen-plus-0806大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-0919
type: llm
provider: dashscope
description: qwen-plus-0919大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1125
type: llm
provider: dashscope
description: qwen-plus-1125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1127
type: llm
provider: dashscope
description: qwen-plus-1127大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-plus-1220
type: llm
provider: dashscope
description: qwen-plus-1220大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen-vl-max
type: chat
provider: dashscope
description: qwen-vl-max多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-0809
type: chat
provider: dashscope
description: qwen-vl-plus-0809多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-2025-01-02
type: chat
provider: dashscope
description: qwen-vl-plus-2025-01-02多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-2025-01-25
type: chat
provider: dashscope
description: qwen-vl-plus-2025-01-25多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus-latest
type: chat
provider: dashscope
description: qwen-vl-plus-latest多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwen-vl-plus
type: chat
provider: dashscope
description: qwen-vl-plus多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwen2.5-0.5b-instruct
type: llm
provider: dashscope
description: qwen2.5-0.5b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-14b
type: llm
provider: dashscope
description: qwen3-14b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b-instruct-2507
type: llm
provider: dashscope
description: qwen3-235b-a22b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b-thinking-2507
type: llm
provider: dashscope
description: qwen3-235b-a22b-thinking-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-235b-a22b
type: llm
provider: dashscope
description: qwen3-235b-a22b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-30b-a3b-instruct-2507
type: llm
provider: dashscope
description: qwen3-30b-a3b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-30b-a3b
type: llm
provider: dashscope
description: qwen3-30b-a3b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-32b
type: llm
provider: dashscope
description: qwen3-32b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-4b
type: llm
provider: dashscope
description: qwen3-4b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-8b
type: llm
provider: dashscope
description: qwen3-8b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-coder-30b-a3b-instruct
type: llm
provider: dashscope
description: qwen3-coder-30b-a3b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-480b-a35b-instruct
type: llm
provider: dashscope
description: qwen3-coder-480b-a35b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-plus-2025-09-23
type: llm
provider: dashscope
description: qwen3-coder-plus-2025-09-23大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-coder-plus
type: llm
provider: dashscope
description: qwen3-coder-plus大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 代码模型
- agent-thought
logo: dashscope
- name: qwen3-max-2025-09-23
type: llm
provider: dashscope
description: qwen3-max-2025-09-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-max-2026-01-23
type: llm
provider: dashscope
description: qwen3-max-2026-01-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-max-preview
type: llm
provider: dashscope
description: qwen3-max-preview大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-max
type: llm
provider: dashscope
description: qwen3-max大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- 联网搜索
logo: dashscope
- name: qwen3-next-80b-a3b-instruct
type: llm
provider: dashscope
description: qwen3-next-80b-a3b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-next-80b-a3b-thinking
type: llm
provider: dashscope
description: qwen3-next-80b-a3b-thinking大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwen3-omni-flash-2025-12-01
type: llm
provider: dashscope
description: qwen3-omni-flash-2025-12-01多模态大语言模型支持视觉、智能体思考、视频、音频能力65536上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
- audio
logo: dashscope
- name: qwen3-vl-235b-a22b-instruct
type: chat
provider: dashscope
description: qwen3-vl-235b-a22b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- video
logo: dashscope
- name: qwen3-vl-235b-a22b-thinking
type: chat
provider: dashscope
description: qwen3-vl-235b-a22b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- video
logo: dashscope
- name: qwen3-vl-30b-a3b-instruct
type: chat
provider: dashscope
description: qwen3-vl-30b-a3b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- video
logo: dashscope
- name: qwen3-vl-30b-a3b-thinking
type: chat
provider: dashscope
description: qwen3-vl-30b-a3b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- video
logo: dashscope
- name: qwen3-vl-flash
type: chat
provider: dashscope
description: qwen3-vl-flash多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- video
logo: dashscope
- name: qwen3-vl-plus-2025-09-23
type: chat
provider: dashscope
description: qwen3-vl-plus-2025-09-23多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwen3-vl-plus
type: chat
provider: dashscope
description: qwen3-vl-plus多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- 多模态模型
- vision
- agent-thought
- video
logo: dashscope
- name: qwq-32b
type: llm
provider: dashscope
description: qwq-32b大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwq-plus-0305
type: llm
provider: dashscope
description: qwq-plus-0305大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: qwq-plus
type: llm
provider: dashscope
description: qwq-plus大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- stream-tool-call
logo: dashscope
- name: gte-rerank-v2
type: rerank
provider: dashscope
description: gte-rerank-v2重排序模型4000上下文窗口
is_deprecated: false
is_official: true
tags:
- 重排序模型
logo: dashscope
- name: gte-rerank
type: rerank
provider: dashscope
description: gte-rerank重排序模型4000上下文窗口
is_deprecated: false
is_official: true
tags:
- 重排序模型
logo: dashscope
- name: multimodal-embedding-v1
type: embedding
provider: dashscope
description: multimodal-embedding-v1多模态嵌入模型支持视觉能力8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
tags:
- 嵌入模型
- 多模态模型
- vision
logo: dashscope
- name: text-embedding-v1
type: embedding
provider: dashscope
description: text-embedding-v1文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v2
type: embedding
provider: dashscope
description: text-embedding-v2文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v3
type: embedding
provider: dashscope
description: text-embedding-v3文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope
- name: text-embedding-v4
type: embedding
provider: dashscope
description: text-embedding-v4文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope

View File

@@ -0,0 +1,143 @@
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
import os
from pathlib import Path
from typing import Callable
import yaml
from sqlalchemy.orm import Session
from app.models.models_model import ModelBase, ModelProvider
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
"""从YAML文件加载指定供应商的模型配置"""
config_dir = Path(__file__).parent
config_file = config_dir / f"{provider.value}_models.yaml"
if not config_file.exists():
return []
with open(config_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
# 检查是否需要加载(默认为 true
if not data.get('enabled', True):
return []
return data.get('models', [])
def _disable_yaml_config(provider: ModelProvider) -> None:
"""将YAML文件的enabled标志设置为false"""
config_dir = Path(__file__).parent
config_file = config_dir / f"{provider.value}_models.yaml"
if not config_file.exists():
return
with open(config_file, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
data['enabled'] = False
with open(config_file, 'w', encoding='utf-8') as f:
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
"""
加载模型配置到数据库
Args:
db: 数据库会话
providers: 要加载的供应商列表None表示加载所有
silent: 是否静默模式(不输出详细日志)
Returns:
dict: 加载结果统计 {"success": int, "skipped": int, "failed": int}
"""
result = {"success": 0, "skipped": 0, "failed": 0}
# 确定要加载的供应商
if providers:
target_providers = [ModelProvider(p) if isinstance(p, str) else p for p in providers]
else:
target_providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
for provider in target_providers:
# 从YAML文件加载模型配置
models = _load_yaml_config(provider)
if not models:
if not silent:
print(f"警告: 供应商 '{provider.value}' 暂无预定义模型")
continue
if not silent:
print(f"\n正在加载 {provider.value}{len(models)} 个模型...")
# provider_success = 0
for model_data in models:
try:
# 检查模型是否已存在
existing = db.query(ModelBase).filter(
ModelBase.name == model_data["name"],
ModelBase.provider == model_data["provider"]
).first()
if existing:
# 更新现有模型配置
for key, value in model_data.items():
setattr(existing, key, value)
db.commit()
if not silent:
print(f"更新成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
else:
# 创建新模型
model = ModelBase(**model_data)
db.add(model)
db.commit()
if not silent:
print(f"添加成功: {model_data['name']}")
result["success"] += 1
# provider_success += 1
except Exception as e:
db.rollback()
if not silent:
print(f"添加失败: {model_data['name']} - {str(e)}")
result["failed"] += 1
# 如果该供应商的模型全部加载成功将enabled设置为false
# if provider_success == len(models):
_disable_yaml_config(provider)
return result
def load_models_by_provider(db: Session, provider: str) -> dict:
"""
加载指定供应商的模型配置
Args:
db: 数据库会话
provider: 供应商名称字符串或ModelProvider枚举
Returns:
dict: 加载结果统计
"""
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
return load_models(db, providers=[provider_enum])
def get_available_providers() -> list[Callable[[], str]]:
"""获取所有可用的供应商列表从ModelProvider枚举获取排除COMPOSITE"""
return [p.value for p in ModelProvider if p != ModelProvider.COMPOSITE]
def get_models_by_provider(provider: str) -> list[dict]:
"""获取指定供应商的模型配置列表"""
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
return _load_yaml_config(provider_enum)

View File

@@ -0,0 +1,294 @@
provider: openai
enabled: true
models:
- name: chatgpt-4o-latest
type: llm
provider: openai
description: chatgpt-4o-latest大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
logo: openai
- name: gpt-3.5-turbo-0125
type: llm
provider: openai
description: gpt-3.5-turbo-0125大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-1106
type: llm
provider: openai
description: gpt-3.5-turbo-1106大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-16k
type: llm
provider: openai
description: gpt-3.5-turbo-16k大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-3.5-turbo-instruct
type: llm
provider: openai
description: gpt-3.5-turbo-instruct大语言模型4096上下文窗口文本补全模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
logo: openai
- name: gpt-3.5-turbo
type: llm
provider: openai
description: gpt-3.5-turbo大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-0125-preview
type: llm
provider: openai
description: gpt-4-0125-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-1106-preview
type: llm
provider: openai
description: gpt-4-1106-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-turbo-2024-04-09
type: llm
provider: openai
description: gpt-4-turbo-2024-04-09大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
logo: openai
- name: gpt-4-turbo-preview
type: llm
provider: openai
description: gpt-4-turbo-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
logo: openai
- name: gpt-4-turbo
type: llm
provider: openai
description: gpt-4-turbo大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
logo: openai
- name: o1-preview
type: llm
provider: openai
description: o1-preview大语言模型支持智能体思考128000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
tags:
- 大语言模型
- agent-thought
logo: openai
- name: o1
type: llm
provider: openai
description: o1大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- structured-output
logo: openai
- name: o3-2025-04-16
type: llm
provider: openai
description: o3-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- vision
- stream-tool-call
- structured-output
logo: openai
- name: o3-mini-2025-01-31
type: llm
provider: openai
description: o3-mini-2025-01-31大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- stream-tool-call
- structured-output
logo: openai
- name: o3-mini
type: llm
provider: openai
description: o3-mini大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- stream-tool-call
- structured-output
logo: openai
- name: o3-pro-2025-06-10
type: llm
provider: openai
description: o3-pro-2025-06-10大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- vision
- structured-output
logo: openai
- name: o3-pro
type: llm
provider: openai
description: o3-pro大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- vision
- structured-output
logo: openai
- name: o3
type: llm
provider: openai
description: o3大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- vision
- tool-call
- stream-tool-call
- structured-output
logo: openai
- name: o4-mini-2025-04-16
type: llm
provider: openai
description: o4-mini-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- vision
- stream-tool-call
- structured-output
logo: openai
- name: o4-mini
type: llm
provider: openai
description: o4-mini大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
tags:
- 大语言模型
- agent-thought
- tool-call
- vision
- stream-tool-call
- structured-output
logo: openai
- name: text-embedding-3-large
type: embedding
provider: openai
description: text-embedding-3-large文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
tags:
- 文本向量模型
logo: openai
- name: text-embedding-3-small
type: embedding
provider: openai
description: text-embedding-3-small文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
tags:
- 文本向量模型
logo: openai
- name: text-embedding-ada-002
type: embedding
provider: openai
description: text-embedding-ada-002文本向量模型8097上下文窗口最大分块数32
is_deprecated: false
is_official: true
tags:
- 文本向量模型
logo: openai

View File

@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
import logging
logger = logging.getLogger(__name__)
def knowledge_retrieval(
query: str,
@@ -62,7 +64,15 @@ def knowledge_retrieval(
merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024)
use_graph = config.get("use_graph", "false").lower() == "true"
# use_graph = config.get("use_graph", "false").lower() == "true"
use_graph_value = config.get("use_graph", False)
if isinstance(use_graph_value, bool):
use_graph = use_graph_value
elif isinstance(use_graph_value, str):
use_graph = use_graph_value.lower() in ("true", "1", "yes")
else:
use_graph = False
file_names_filter = []
if user_ids:
@@ -159,13 +169,29 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking
if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
# use graph
try:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
except Exception as rerank_error:
# If reranker fails, log warning and continue with original results
logger.warning(
"Reranker failed, falling back to original results",
extra={
"reranker_id": reranker_id,
"query": query,
"doc_count": len(all_results),
"error": str(rerank_error),
},
)
if use_graph:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
try:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
except Exception as graph_error:
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
return all_results
except Exception as e:

View File

@@ -16,7 +16,6 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.template_renderer import render_template
logger = logging.getLogger(__name__)
@@ -157,12 +156,137 @@ class WorkflowExecutor:
"error": result.get("error"),
}
def _update_end_activate(self, node_id):
def _update_scope_activate(self, scope, status=None):
"""
Update the activation state of all End nodes based on a completed scope (node or variable).
Iterates over all End nodes in `self.end_outputs` and calls
`update_activate` on each, which may:
- Activate variable segments that depend on the completed node/scope.
- Activate the entire End node output if all control conditions are met.
If any End node becomes active and `self.activate_end` is not yet set,
this node will be marked as the currently active End node.
Args:
scope (str): The node ID or scope that has completed execution.
status (str | None): Optional status of the node (used for branch/control nodes).
"""
for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(node_id)
self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and self.activate_end is None:
self.activate_end = node
def _update_stream_output_status(self, activate, data):
"""
Update the stream output state of End nodes based on workflow state updates.
This method checks which nodes/scopes are activated and propagates
activation to End nodes accordingly.
Args:
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
data (dict): Mapping of node_id -> node runtime data, including outputs.
Behavior:
For each node in `data`:
1. If the node is activated (`activate[node_id]` is True),
retrieve its output status from `runtime_vars`.
2. Call `_update_scope_activate` to propagate the activation
to all relevant End nodes and update `self.activate_end`.
"""
for node_id in data.keys():
if activate.get(node_id):
node_output_status = (
data[node_id]
.get('runtime_vars', {})
.get(node_id)
.get("output")
)
self._update_scope_activate(node_id, status=node_output_status)
async def _emit_active_chunks(
self,
node_outputs: dict,
variables: dict,
force=False
):
"""
Process and yield all currently active output segments for the currently active End node.
This method handles stream-mode output for an End node by iterating through its output segments
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
`force=True`, which allows all segments to be processed regardless of their activation state.
Behavior:
1. Iterates from the current `cursor` position to the end of the outputs list.
2. For each segment:
- If the segment is literal text (`is_variable=False`), append it directly.
- If the segment is a variable (`is_variable=True`), evaluate it using
`evaluate_expression` with the given `node_outputs` and `variables`,
then transform the result with `_trans_output_string`.
3. Yield a stream event of type "message" containing the processed chunk.
4. Move the `cursor` forward after processing each segment.
5. When all segments have been processed, remove this End node from `end_outputs`
and reset `activate_end` to None.
Args:
node_outputs (dict): Current runtime node outputs, used for variable evaluation.
variables (dict): Current runtime variables, used for variable evaluation.
force (bool, default=False): If True, process segments even if `activate=False`.
Yields:
dict: A stream event of type "message" containing the processed chunk.
Notes:
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
- This method only processes the currently active End node (`self.activate_end`).
- Use `force=True` for final emission regardless of activation state.
"""
end_info = self.end_outputs[self.activate_end]
while end_info.cursor < len(end_info.outputs):
final_chunk = ''
current_segment = end_info.outputs[end_info.cursor]
if not current_segment.activate and not force:
# Stop processing until this segment becomes active
break
# Literal segment
if not current_segment.is_variable:
final_chunk += current_segment.literal
else:
# Variable segment: evaluate and transform
try:
chunk = evaluate_expression(
current_segment.literal,
variables=variables,
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
final_chunk += chunk
except ValueError:
# Log failed evaluation but continue streaming
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
if final_chunk:
yield {
"event": "message",
"data": {
"chunk": final_chunk
}
}
# Advance cursor after processing
end_info.cursor += 1
# Remove End node from active tracking if all segments have been processed
if end_info.cursor >= len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
@staticmethod
def _trans_output_string(content):
if isinstance(content, str):
@@ -218,14 +342,8 @@ class WorkflowExecutor:
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
full_content = ''
for end_info in self.end_outputs.values():
output_template = "".join([output.literal for output in end_info.outputs])
full_content += render_template(
output_template,
result.get("variables", {}),
result.get("runtime_vars", {}),
strict=False
)
for end_id in self.end_outputs.keys():
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
result["messages"].extend(
[
{
@@ -306,7 +424,7 @@ class WorkflowExecutor:
try:
chunk_count = 0
full_content = ''
self._update_scope_activate("sys")
async for event in graph.astream(
initial_state,
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
@@ -333,9 +451,12 @@ class WorkflowExecutor:
if not end_info or end_info.cursor >= len(end_info.outputs):
continue
current_output = end_info.outputs[end_info.cursor]
if current_output.is_variable and current_output.depends_on_node(node_id):
if current_output.is_variable and current_output.depends_on_scope(node_id):
if data.get("done"):
end_info.cursor += 1
if end_info.cursor >= len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
else:
full_content += data.get("chunk")
yield {
@@ -415,91 +536,53 @@ class WorkflowExecutor:
elif mode == "updates":
# Handle state updates - store final state
for node_id in data.keys():
self._update_end_activate(node_id)
wait = False
state = graph.get_state(config=self.checkpoint_config)
node_outputs = state.values.get("runtime_vars", {})
for _ in data.keys():
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
state = graph.get_state(config=self.checkpoint_config).values
node_outputs = state.get("runtime_vars", {})
variables = state.get("variables", {})
activate = state.get("activate", {})
for _, node_data in data.items():
node_outputs |= node_data.get("runtime_vars", {})
variables |= node_data.get("variables", {})
self._update_stream_output_status(activate, data)
wait = False
while self.activate_end and not wait:
message = ''
logger.info(self.activate_end)
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
while content.activate:
if not content.is_variable:
full_content += content.literal
message += content.literal
else:
try:
chunk = evaluate_expression(
content.literal,
variables={},
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
break
content = end_info.outputs[end_info.cursor]
if end_info.cursor != len(end_info.outputs):
async for msg_event in self._emit_active_chunks(
node_outputs=node_outputs,
variables=variables
):
full_content += msg_event["data"]['chunk']
yield msg_event
if self.activate_end:
wait = True
else:
self.end_outputs.pop(self.activate_end)
self.activate_end = None
for node_id in data.keys():
self._update_end_activate(node_id)
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
self._update_stream_output_status(activate, data)
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
f"- execution_id: {self.execution_id}")
result = graph.get_state(self.checkpoint_config).values
while self.activate_end:
message = ''
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
if not content.is_variable:
message += content.literal
else:
node_outputs = result.get("runtime_vars", {})
variables = result.get("variables", {})
try:
chunk = evaluate_expression(
content.literal,
node_outputs = result.get("runtime_vars", {})
variables = result.get("variables", {})
self.end_outputs = {
node_id: node_info
for node_id, node_info in self.end_outputs.items()
if node_info.activate
}
if self.end_outputs or self.activate_end:
while self.activate_end:
async for msg_event in self._emit_active_chunks(
node_outputs=node_outputs,
variables=variables,
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
if self.end_outputs:
force=True
):
full_content += msg_event["data"]['chunk']
yield msg_event
if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0]
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
# 计算耗时
end_time = datetime.datetime.now()

View File

@@ -53,114 +53,110 @@ class OutputContent(BaseModel):
)
)
def depends_on_node(self, node_id: str) -> bool:
def depends_on_scope(self, scope: str) -> bool:
"""
Check if this output segment depends on a specific node's variable.
This method examines the `literal` of the output segment to see if it
contains a variable placeholder referencing the given node in the form:
{{ node_id.field_name }}
It uses a regular expression to match the exact node ID, avoiding
false positives from substring matches (e.g., 'node1' should not match 'node10').
Check if this segment depends on a given scope.
Args:
node_id (str): The ID of the node to check for in this segment's variable placeholders.
scope (str): Node ID or special variable prefix (e.g., "sys").
Returns:
bool:
- True if the segment contains a variable referencing the given node.
- False otherwise.
Example:
literal = "{{node1.name}}"
depends_on_node("node1") -> True
depends_on_node("node2") -> False
Usage:
This method is primarily used in stream mode to determine whether
a particular variable output segment should be activated when a
specific upstream node completes execution.
bool: True if this segment references the given scope.
"""
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
pattern = re.compile(variable_pattern)
match = pattern.search(self.literal)
if match:
return True
return False
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
return bool(re.search(pattern, self.literal))
class StreamOutputConfig(BaseModel):
"""
Streaming output configuration for an End node.
This structure controls:
- whether the End node output is globally active
- which upstream branch nodes are responsible for activation
- how each output segment behaves in streaming mode
This configuration describes how the End node output behaves in streaming mode,
including:
- whether output emission is globally activated
- which upstream branch/control nodes gate the activation
- how each parsed output segment is streamed and activated
"""
activate: bool = Field(
...,
description=(
"Global activation state of the End node output.\n"
"If False, no output should be emitted until all control nodes are resolved."
"Global activation flag for the End node output.\n"
"When False, output segments should not be emitted even if available.\n"
"This flag typically becomes True once required control branch conditions "
"are satisfied."
)
)
control_nodes: list[str] = Field(
control_nodes: dict[str, str] = Field(
...,
description=(
"List of upstream branch node IDs that control this End node.\n"
"Each node must signal completion before output becomes active."
"Control branch conditions for this End node output.\n"
"Mapping of `branch_node_id -> expected_branch_label`.\n"
"The End node output becomes globally active when a controlling branch node "
"reports a matching completion status."
)
)
outputs: list[OutputContent] = Field(
...,
description="Ordered list of output segments parsed from the output template."
description=(
"Ordered list of output segments parsed from the output template.\n"
"Each segment represents either a literal text block or a variable placeholder "
"that may be activated independently."
)
)
cursor: int = Field(
...,
description=(
"Streaming cursor index.\n"
"Indicates how many output segments have already been emitted."
"Indicates the next output segment index to be emitted.\n"
"Segments with index < cursor are considered already streamed."
)
)
def update_activate(self, node_id):
def update_activate(self, scope: str, status=None):
"""
Update activation state based on an upstream node completion.
Update streaming activation state based on an upstream node or special variable.
This method is typically called when a branch/control node finishes execution.
Args:
scope (str):
Identifier of the completed upstream entity.
- If a control branch node, it should match a key in `control_nodes`.
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
status (optional):
Completion status of the control branch node.
Required when `scope` refers to a control node.
Behavior:
1. If the node is a control node:
- Remove it from `control_nodes`
- If all control nodes are resolved, activate the entire output
1. Control branch nodes:
- If `scope` matches a key in `control_nodes` and `status` matches the expected
branch label, the End node output becomes globally active (`activate = True`).
2. Activate variable output segments that depend on this node:
- If an output segment is a variable
- And its literal references the completed node_id
- Mark that segment as active
2. Variable output segments:
- For each segment that is a variable (`is_variable=True`):
- If the segment literal references `scope`, mark the segment as active.
- This applies both to regular node variables (e.g., "node_id.field")
and special system variables (e.g., "sys.xxx").
Notes:
- This method does not emit output or advance the streaming cursor.
- It only updates activation flags based on upstream events or special variables.
"""
# Case 1: resolve control branch dependency
if node_id in self.control_nodes:
self.control_nodes.remove(node_id)
# All branch constraints resolved → enable output
if not self.control_nodes:
if scope in self.control_nodes.keys():
if status is None:
raise RuntimeError("[Stream Output] Control node activation status not provided")
if status == self.control_nodes[scope]:
self.activate = True
# Case 2: activate variable segments related to this node
for i in range(len(self.outputs)):
if (
self.outputs[i].is_variable
and self.outputs[i].depends_on_node(node_id)
and self.outputs[i].depends_on_scope(scope)
):
self.outputs[i].activate = True
@@ -184,11 +180,11 @@ class GraphBuilder:
self._find_upstream_branch_node = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
self._analyze_end_node_output()
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
self._analyze_end_node_output()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
@property
@@ -216,30 +212,53 @@ class GraphBuilder:
except KeyError:
raise RuntimeError(f"Node not found: Id={node_id}")
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
"""Find upstream branch nodes for a given target node in the workflow graph.
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
"""
Recursively find all upstream branch (control) nodes that influence the execution
of the given target node.
This method identifies all upstream control (branch) nodes that can affect
the execution of `target_node`. If `target_node` is reachable from a start
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
This method walks upstream along the workflow graph starting from `target_node`.
It distinguishes between:
- branch nodes (node types listed in `BRANCH_NODES`)
- non-branch nodes (ordinary processing nodes)
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
and non-branch nodes, recursively traversing upstream through non-branch
nodes. If any non-branch upstream path does not lead to a branch node,
the result will indicate that no valid upstream branch node exists.
Traversal rules:
1. For each immediate upstream node:
- If it is a branch node, it is recorded as an affecting control node.
- If it is a non-branch node, the traversal continues recursively upstream.
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
a branch node, the traversal is considered invalid:
- `has_branch` will be False
- no branch nodes are returned.
3. Only when ALL upstream non-branch paths eventually lead to at least one
branch node will `has_branch` be True.
Special case:
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
it is considered directly reachable from the workflow entry, and therefore
has no controlling branch nodes.
Args:
target_node (str): The identifier of the target node.
target_node (str):
The identifier of the node whose upstream control branches
are to be resolved.
Returns:
tuple[bool, tuple[str]]:
- has_branch (bool): True if all upstream non-branch paths lead to at least
one branch node; False if any path reaches a start node without a branch.
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
tuple[bool, tuple[tuple[str, str]]]:
- has_branch (bool):
True if every upstream path from `target_node` encounters
at least one branch node.
False if any path reaches a start node without a branch.
- branch_nodes (tuple[tuple[str, str]]):
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
representing all branch nodes that can influence `target_node`.
Returns an empty tuple if `has_branch` is False.
"""
source_nodes = [
edge.get("source")
{
"id": edge.get("source"),
"branch": edge.get("label")
}
for edge in self.edges
if edge.get("target") == target_node
]
@@ -249,11 +268,13 @@ class GraphBuilder:
branch_nodes = []
non_branch_nodes = []
for node_id in source_nodes:
if self.get_node_type(node_id) in BRANCH_NODES:
branch_nodes.append(node_id)
for node_info in source_nodes:
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
branch_nodes.append(
(node_info["id"], node_info["branch"])
)
else:
non_branch_nodes.append(node_id)
non_branch_nodes.append(node_info["id"])
has_branch = True
for node_id in non_branch_nodes:
@@ -334,7 +355,7 @@ class GraphBuilder:
activate=not has_branch,
# Branch nodes that control activation of this End node
control_nodes=list(control_nodes),
control_nodes=dict(control_nodes),
# Convert output segments into OutputContent objects
outputs=list(
@@ -362,7 +383,7 @@ class GraphBuilder:
else:
self.end_node_map[end_node_id] = StreamOutputConfig(
activate=True,
control_nodes=[],
control_nodes={},
outputs=list(
[
OutputContent(

View File

@@ -44,7 +44,7 @@ class CodeNodeConfig(BaseNodeConfig):
description="code content"
)
language: Literal['python3', 'nodejs'] = Field(
language: Literal['python3', 'javascript'] = Field(
...,
description="language"
)

View File

@@ -2,6 +2,7 @@ import base64
import json
import logging
import re
import urllib.parse
from string import Template
from textwrap import dedent
from typing import Any
@@ -14,7 +15,7 @@ from app.core.workflow.nodes.code.config import CodeNodeConfig
logger = logging.getLogger(__name__)
SCRIPT_TEMPLATE = Template(dedent("""
PYTHON_SCRIPT_TEMPLATE = Template(dedent("""
$code
import json
@@ -32,6 +33,20 @@ result = "<<RESULT>>" + output_json + "<<RESULT>>"
print(result)
"""))
NODEJS_SCRIPT_TEMPLATE = Template(dedent("""
$code
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('$inputs_variable', 'base64').toString('utf-8'))
// execute main function
var output_obj = main(inputs_obj)
// convert output to json and print
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>$${output_json}<<RESULT>>`
console.log(result)
"""))
class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -83,18 +98,27 @@ class CodeNode(BaseNode):
input_variable_dict = {}
for input_variable in self.typed_config.input_variables:
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
code = base64.b64decode(
self.typed_config.code
).decode("utf-8")
code = urllib.parse.unquote(code, encoding='utf-8')
input_variable_dict = base64.b64encode(
json.dumps(input_variable_dict).encode("utf-8")
).decode("utf-8")
final_script = SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
if self.typed_config.language == "python3":
final_script = PYTHON_SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
elif self.typed_config.language == 'javascript':
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
else:
raise ValueError(f"Unsupported language: {self.typed_config.language}")
async with httpx.AsyncClient() as client:
response = await client.post(

View File

@@ -25,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
...
)
config_id: UUID = Field(
config_id: UUID | int = Field(
...
)

View File

@@ -36,9 +36,10 @@ class MemoryReadNode(BaseNode):
class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryWriteNodeConfig(**self.config)
self.typed_config: MemoryWriteNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", state)
if not end_user_id:

View File

@@ -23,6 +23,18 @@ class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
@staticmethod
def _get_prompt():
@@ -171,6 +183,7 @@ class ParameterExtractorNode(BaseNode):
])
model_resp = await llm.ainvoke(messages)
self.response_metadata = model_resp.response_metadata
result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")

View File

@@ -23,6 +23,18 @@ class QuestionClassifierNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
self.response_metadata = {}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
if self.response_metadata:
usage = self.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
@@ -112,6 +124,7 @@ class QuestionClassifierNode(BaseNode):
response = await llm.ainvoke(messages)
result = response.content.strip()
self.response_metadata = response.response_metadata
if result in category_names:
category = result

View File

@@ -4,16 +4,19 @@
从文件系统加载预定义的工作流模板
"""
import os
from pathlib import Path
from typing import Optional
import yaml
TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
class TemplateLoader:
"""工作流模板加载器"""
def __init__(self, templates_dir: str = "app/templates/workflows"):
def __init__(self, templates_dir: str = TEMPLATE_DIR):
"""初始化模板加载器
Args:
@@ -22,7 +25,7 @@ class TemplateLoader:
self.templates_dir = Path(templates_dir)
if not self.templates_dir.exists():
raise ValueError(f"模板目录不存在: {templates_dir}")
def list_templates(self) -> list[dict]:
"""列出所有可用的模板
@@ -30,22 +33,22 @@ class TemplateLoader:
模板列表,每个模板包含 id, name, description 等信息
"""
templates = []
# 遍历模板目录
for template_dir in self.templates_dir.iterdir():
if not template_dir.is_dir():
continue
# 检查是否有 template.yml 文件
template_file = template_dir / "template.yml"
if not template_file.exists():
continue
try:
# 读取模板配置
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 提取模板信息
templates.append({
"id": template_dir.name,
@@ -59,9 +62,9 @@ class TemplateLoader:
except Exception as e:
print(f"加载模板 {template_dir.name} 失败: {e}")
continue
return templates
def load_template(self, template_id: str) -> Optional[dict]:
"""加载指定的模板
@@ -73,14 +76,14 @@ class TemplateLoader:
"""
template_dir = self.templates_dir / template_id
template_file = template_dir / "template.yml"
if not template_file.exists():
return None
try:
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 返回工作流配置部分
return {
"name": template_data.get("name", template_id),
@@ -94,7 +97,7 @@ class TemplateLoader:
except Exception as e:
print(f"加载模板 {template_id} 失败: {e}")
return None
def get_template_readme(self, template_id: str) -> Optional[str]:
"""获取模板的 README 文档
@@ -106,10 +109,10 @@ class TemplateLoader:
"""
template_dir = self.templates_dir / template_id
readme_file = template_dir / "README.md"
if not readme_file.exists():
return None
try:
with open(readme_file, 'r', encoding='utf-8') as f:
return f.read()

View File

@@ -16,6 +16,8 @@ from app.core.error_codes import BizCode, HTTP_MAPPING
from app.core.exceptions import BusinessException
from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail
from app.core.models.scripts.loader import load_models
from app.db import get_db_context
# Initialize logging system
LoggingConfig.setup_logging()
@@ -47,6 +49,15 @@ async def lifespan(app: FastAPI):
else:
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
# 加载预定义模型
logger.info("开始加载预定义模型...")
try:
with get_db_context() as db:
result = load_models(db, silent=True)
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
except Exception as e:
logger.warning(f"加载预定义模型时出错: {str(e)}")
logger.info("应用程序启动完成")
yield
# 应用关闭事件

View File

@@ -28,6 +28,10 @@ from .tool_model import (
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
from .memory_perceptual_model import MemoryPerceptualModel
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
__all__ = [
"Tenants",

View File

@@ -20,6 +20,9 @@ class MemoryConfig(Base):
end_user_id = Column(String, nullable=True, comment="组ID")
user_id = Column(String, nullable=True, comment="用户ID")
apply_id = Column(String, nullable=True, comment="应用ID")
# 本体场景关联
scene_id = Column(UUID(as_uuid=True), nullable=True, comment="本体场景ID关联ontology_scene表")
# 模型选择从workspace继承
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")

View File

@@ -5,6 +5,7 @@ from enum import StrEnum
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.db import Base
@@ -23,6 +24,8 @@ class ModelType(StrEnum):
CHAT = "chat"
EMBEDDING = "embedding"
RERANK = "rerank"
# TTS = "tts"
# SPEECH2TEXT = "speech2text"
# IMAGE = "image"
# AUDIO = "audio"
# VISION = "vision"
@@ -48,8 +51,7 @@ class ModelProvider(StrEnum):
class LoadBalanceStrategy(StrEnum):
"""API Key负载均衡策略枚举"""
ROUND_ROBIN = "round_robin" # 轮询
WEIGHTED_ROUND_ROBIN = "weighted_round_robin" # 加权轮询
RANDOM = "random" # 随机
NONE = "none" #
# 多对多关联表
@@ -90,7 +92,8 @@ class ModelConfig(BaseModel):
# 状态管理
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开")
load_balance_strategy = Column(String, nullable=True, comment="负载均衡策略")
load_balance_strategy = Column(String, nullable=True, comment="负载均衡策略", default=LoadBalanceStrategy.NONE,
server_default=LoadBalanceStrategy.NONE)
# 关联关系
model_base = relationship("ModelBase", back_populates="configs")
@@ -151,6 +154,7 @@ class ModelBase(Base):
is_official = Column(Boolean, default=True, comment="是否供应商官方模型(区分自定义)")
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作']")
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now())
# 关联关系
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")

View File

@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
"""本体类型模型
本模块定义本体类型的数据模型。
Classes:
OntologyClass: 本体类型表模型
"""
import datetime
import uuid
from sqlalchemy import Column, String, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
class OntologyClass(Base):
"""本体类型表 - 用于存储某个场景提取出来的本体类型信息"""
__tablename__ = "ontology_class"
# 主键
class_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="类型ID")
# 类型信息
class_name = Column(String(200), nullable=False, comment="类型名称")
class_description = Column(Text, nullable=True, comment="类型描述")
# 外键:关联到本体场景
scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
# 关系:类型属于某个场景
scene = relationship("OntologyScene", back_populates="classes")
def __repr__(self):
return f"<OntologyClass(id={self.class_id}, name={self.class_name}, scene_id={self.scene_id})>"

View File

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
"""本体场景模型
本模块定义本体场景的数据模型。
Classes:
OntologyScene: 本体场景表模型
"""
import datetime
import uuid
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
class OntologyScene(Base):
"""本体场景表 - 用于存储本体场景下不同的类型信息"""
__tablename__ = "ontology_scene"
__table_args__ = (
UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name'),
)
# 主键
scene_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="场景ID")
# 场景信息
scene_name = Column(String(200), nullable=False, comment="场景名称")
scene_description = Column(Text, nullable=True, comment="场景描述")
# 外键:关联到工作空间
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
# 关系:一个场景可以有多个类型
classes = relationship("OntologyClass", back_populates="scene", cascade="all, delete-orphan")
def __repr__(self):
return f"<OntologyScene(id={self.scene_id}, name={self.scene_name})>"

View File

@@ -2,7 +2,7 @@ import datetime
import uuid
from enum import StrEnum
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index, Boolean
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
@@ -121,10 +121,33 @@ class PromptOptimizerSessionHistory(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID")
session_id = Column(
UUID(as_uuid=True),
ForeignKey("prompt_opt_session_list.id"),
nullable=False,
comment="Session ID"
)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
role = Column(String, nullable=False, comment="Message Role")
content = Column(Text, nullable=False, comment="Message Content")
# prompt = Column(Text, nullable=False, comment="Prompt")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
class PromptHistory(Base):
__tablename__ = "prompt_history"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
session_id = Column(
UUID(as_uuid=True),
ForeignKey("prompt_opt_session_list.id"),
nullable=False,
comment="Session ID"
)
title = Column(String, nullable=False, comment="Title")
prompt = Column(Text, nullable=False, comment="Prompt")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
is_delete = Column(Boolean, default=False, comment="Delete")

View File

@@ -24,12 +24,16 @@ from app.schemas.memory_storage_schema import (
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
# 获取数据库专用日志器
db_logger = get_db_logger()
# 获取配置专用日志器
config_logger = get_config_logger()
TABLE_NAME = "memory_config"
class MemoryConfigRepository:
"""记忆配置Repository
@@ -82,7 +86,8 @@ class MemoryConfigRepository:
n.description AS description,
n.entity_type AS entity_type,
n.name AS name,
COALESCE(n.fact_summary, '') AS fact_summary,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(n.fact_summary, '') AS fact_summary,
n.end_user_id AS end_user_id,
n.apply_id AS apply_id,
n.user_id AS user_id,
@@ -152,7 +157,7 @@ class MemoryConfigRepository:
return memory_config_obj
@staticmethod
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args:
@@ -168,6 +173,7 @@ class MemoryConfigRepository:
if not memory_config:
raise RuntimeError("reflection config not found")
return memory_config
@staticmethod
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
@@ -187,7 +193,6 @@ class MemoryConfigRepository:
raise RuntimeError("reflection config not found")
return memory_config
@staticmethod
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
@@ -227,9 +232,12 @@ class MemoryConfigRepository:
config_name=params.config_name,
config_desc=params.config_desc,
workspace_id=params.workspace_id,
scene_id=params.scene_id,
llm_id=params.llm_id,
embedding_id=params.embedding_id,
rerank_id=params.rerank_id,
reflection_model_id=params.reflection_model_id,
emotion_model_id=params.emotion_model_id,
)
db.add(db_config)
db.flush() # 获取自增ID但不提交事务
@@ -272,6 +280,9 @@ class MemoryConfigRepository:
if update.config_desc is not None:
db_config.config_desc = update.config_desc
has_update = True
if update.scene_id is not None:
db_config.scene_id = update.scene_id
has_update = True
if not has_update:
raise ValueError("No fields to update")
@@ -287,7 +298,6 @@ class MemoryConfigRepository:
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
"""更新记忆萃取引擎配置
@@ -410,7 +420,7 @@ class MemoryConfigRepository:
raise
@staticmethod
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
def get_extracted_config(db: Session, config_id: UUID | int) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置
Args:
@@ -420,8 +430,8 @@ class MemoryConfigRepository:
Returns:
Optional[Dict]: 萃取配置字典不存在则返回None
"""
config_id = resolve_config_id(config_id, db)
db_logger.debug(f"查询萃取配置: config_id={config_id}")
try:
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
if not db_config:
@@ -514,26 +524,28 @@ class MemoryConfigRepository:
except Exception as e:
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
"""Get memory config and its associated workspace information
Args:
db: Database session
config_id: Configuration ID
Returns:
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
Raises:
ValueError: Raised when config exists but workspace doesn't
"""
import time
from app.models.workspace_model import Workspace
start_time = time.time()
config_id = resolve_config_id(config_id, db)
# Log configuration loading start
config_logger.info(
"Loading configuration with workspace",
@@ -542,17 +554,17 @@ class MemoryConfigRepository:
"config_id": config_id
}
)
db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
try:
# Use join query to get both config and workspace
result = db.query(MemoryConfig, Workspace).join(
Workspace, MemoryConfig.workspace_id == Workspace.id
).filter(MemoryConfig.config_id == config_id).first()
elapsed_ms = (time.time() - start_time) * 1000
if not result:
# Check if config exists but workspace is missing
config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
@@ -581,9 +593,11 @@ class MemoryConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
db_logger.error(
f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(
f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
config_logger.debug(
"Configuration not found",
extra={
@@ -595,9 +609,9 @@ class MemoryConfigRepository:
)
db_logger.debug(f"Memory config not found: config_id={config_id}")
return None
config, workspace = result
# Log successful configuration loading
config_logger.info(
"Configuration with workspace loaded successfully",
@@ -612,16 +626,17 @@ class MemoryConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
db_logger.debug(
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace)
except ValueError:
# Re-raise known business exceptions
raise
except Exception as e:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Failed to load configuration with workspace",
extra={
@@ -634,32 +649,37 @@ class MemoryConfigRepository:
},
exc_info=True
)
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
"""获取所有配置参数
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
"""获取所有配置参数,包含关联的场景名称
Args:
db: 数据库会话
workspace_id: 工作空间ID用于过滤查询结果
Returns:
List[MemoryConfig]: 配置列表
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
"""
from app.models.ontology_scene import OntologyScene
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
try:
query = db.query(MemoryConfig)
query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
)
if workspace_id:
query = query.filter(MemoryConfig.workspace_id == workspace_id)
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
results = query.order_by(desc(MemoryConfig.updated_at)).all()
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
return configs
db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
return results
except Exception as e:
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")

View File

@@ -165,7 +165,7 @@ class ModelConfigRepository:
total = base_query.count()
# 分页查询
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
models = base_query.order_by(desc(ModelConfig.created_at)).offset(
(query.page - 1) * query.pagesize
).limit(query.pagesize).all()
@@ -234,7 +234,7 @@ class ModelConfigRepository:
# 获取总数
total = base_query.count()
query_results = base_query.order_by(desc(ModelConfig.updated_at)).all()
query_results = base_query.order_by(desc(ModelConfig.created_at)).all()
provider_groups: Dict[str, List[ModelConfig]] = {}
for model_config in query_results:
@@ -433,6 +433,7 @@ class ModelConfigRepository:
ModelConfig.is_public
),
ModelBase.provider == provider,
ModelConfig.is_active,
~ModelConfig.is_composite
)
).distinct().all()
@@ -621,7 +622,7 @@ class ModelBaseRepository:
if filters:
q = q.filter(and_(*filters))
return q.order_by(ModelBase.add_count.desc()).all()
return q.order_by(ModelBase.add_count.desc(), ModelBase.created_at.desc()).all()
@staticmethod
def create(db: Session, data: dict) -> 'ModelBase':
@@ -629,6 +630,13 @@ class ModelBaseRepository:
db.add(model_base)
return model_base
@staticmethod
def get_by_name_and_provider(db: Session, name: str, provider: str) -> Optional['ModelBase']:
return db.query(ModelBase).filter(
ModelBase.name == name,
ModelBase.provider == provider
).first()
@staticmethod
def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']:
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
@@ -636,6 +644,17 @@ class ModelBaseRepository:
return None
for key, value in data.items():
setattr(model_base, key, value)
# 同步更新绑定的非组合模型配置
if any(k in data for k in ['name', 'description', 'logo']):
db.query(ModelConfig).filter(
ModelConfig.model_id == model_base_id,
ModelConfig.is_composite == False
).update({
k: v for k, v in data.items()
if k in ['name', 'description', 'logo']
}, synchronize_session=False)
return model_base
@staticmethod

View File

@@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
try:
edges: List[dict] = []
for s in summaries:
for chunk_id in getattr(s, "chunk_ids", []) or []:
chunk_ids = getattr(s, "chunk_ids", []) or []
for chunk_id in chunk_ids:
edges.append({
"summary_id": s.id,
"chunk_id": chunk_id,
@@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
if not edges:
return []
result = await connector.execute_query(
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
edges=edges
)
created = [record.get("uuid") for record in result] if result else []
return created
except Exception:
except Exception as e:
return None

View File

@@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
summaries=flattened
)
created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids
except Exception:
except Exception as e:
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None

View File

@@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
e.name_embedding = CASE
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
ELSE e.name_embedding END,
e.fact_summary = CASE
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
THEN entity.fact_summary ELSE e.fact_summary END,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// e.fact_summary = CASE
// WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
// AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
// THEN entity.fact_summary ELSE e.fact_summary END,
e.connect_strength = CASE
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
ELSE CASE
@@ -321,7 +322,8 @@ RETURN e.id AS id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
COALESCE(e.fact_summary, '') AS fact_summary,
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
// COALESCE(e.fact_summary, '') AS fact_summary,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
@@ -877,7 +879,8 @@ RETURN
CASE
WHEN ms:ExtractedEntity THEN {
text: ms.name,
created_at: ms.created_at
created_at: ms.created_at,
type: "情景记忆"
}
END
) AS ExtractedEntity,
@@ -887,7 +890,8 @@ RETURN
CASE
WHEN n:MemorySummary THEN {
text: n.content,
created_at: n.created_at
created_at: n.created_at,
type: "长期沉淀"
}
END
) AS MemorySummary,
@@ -895,7 +899,8 @@ RETURN
collect(
DISTINCT {
text: e.statement,
created_at: e.created_at
created_at: e.created_at,
type: "情绪记忆"
}
) AS statement;
"""
@@ -999,3 +1004,58 @@ RETURN DISTINCT
x.statement as statement,x.created_at as created_at
"""
Graph_Node_query = """
MATCH (n:MemorySummary)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
0 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Dialogue)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT 1
UNION ALL
MATCH (n:Statement)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
1 AS priority
LIMIT $limit
UNION ALL
MATCH (n:ExtractedEntity)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
2 AS priority
LIMIT $limit
UNION ALL
MATCH (n:Chunk)
WHERE n.end_user_id = $end_user_id
RETURN
elementId(n) AS id,
labels(n) AS labels,
properties(n) AS properties,
3 AS priority
LIMIT $limit
"""

View File

@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
ExtractedEntityNode,
EntityEntityEdge,
)
import logging
logger = logging.getLogger(__name__)
async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge],
@@ -41,8 +42,8 @@ async def save_entities_and_relationships(
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
}
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
async def save_dialog_and_statements_to_neo4j(
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -171,40 +172,127 @@ async def save_dialog_and_statements_to_neo4j(
Returns:
bool: True if successful, False otherwise
"""
try:
# Save all dialogue nodes in batch
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
if dialogue_uuids:
# 定义事务函数,将所有写操作放在一个事务中
async def _save_all_in_transaction(tx):
"""在单个事务中执行所有保存操作,避免死锁"""
results = {}
# 1. Save all dialogue nodes in batch
if dialogue_nodes:
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
dialogue_data = [node.model_dump() for node in dialogue_nodes]
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
else:
print("Failed to save dialogues to Neo4j")
return False
# Save all chunk nodes in batch
await save_chunk_nodes(chunk_nodes, connector)
# 2. Save all chunk nodes in batch
if chunk_nodes:
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
chunk_data = [node.model_dump() for node in chunk_nodes]
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
chunk_uuids = [record["uuid"] async for record in result]
results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
# Save all statement nodes in batch
# 3. Save all statement nodes in batch
if statement_nodes:
statement_uuids = await add_statement_nodes(statement_nodes, connector)
if statement_uuids:
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
else:
print("Failed to save statement nodes to Neo4j")
return False
else:
print("No statement nodes to save")
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
statement_data = [node.model_dump() for node in statement_nodes]
result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
statement_uuids = [record["uuid"] async for record in result]
results['statements'] = statement_uuids
logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
# Save entities and relationships
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
print("Successfully saved entities and relationships to Neo4j")
# 4. Save entities
if entity_nodes:
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
entity_data = [entity.model_dump() for entity in entity_nodes]
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
entity_uuids = [record["uuid"] async for record in result]
results['entities'] = entity_uuids
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
# Save new edges
await save_statement_chunk_edges(statement_chunk_edges, connector)
await save_statement_entity_edges(statement_entity_edges, connector)
# 5. Create entity relationships
if entity_edges:
from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
relationship_data = []
for edge in entity_edges:
relationship_data.append({
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
})
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
rel_uuids = [record["uuid"] async for record in result]
results['entity_relationships'] = rel_uuids
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
# 6. Save statement-chunk edges
if statement_chunk_edges:
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
sc_edge_data = []
for edge in statement_chunk_edges:
sc_edge_data.append({
"id": edge.id,
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
})
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
sc_uuids = [record["uuid"] async for record in result]
results['statement_chunk_edges'] = sc_uuids
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
# 7. Save statement-entity edges
if statement_entity_edges:
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
se_edge_data = []
for edge in statement_entity_edges:
se_edge_data.append({
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
"connect_strength": getattr(edge, "connect_strength", "strong"),
})
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
se_uuids = [record["uuid"] async for record in result]
results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
return results
try:
# 使用显式写事务执行所有操作,避免死锁
results = await connector.execute_write_transaction(_save_all_in_transaction)
summary = {
key: len(value)
for key, value in results.items()
if isinstance(value, (list, tuple, set))
}
logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results)
return True
except Exception as e:
logger.error(f"Neo4j integration error: {e}", exc_info=True)
print(f"Neo4j integration error: {e}")
print("Continuing without database storage...")
return False

View File

@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
"""本体类型Repository层
本模块提供本体类型的数据访问层实现。
Classes:
OntologyClassRepository: 本体类型数据访问类
"""
import logging
from typing import List, Optional
from uuid import UUID
from sqlalchemy.orm import Session, joinedload
from app.core.logging_config import get_db_logger
from app.models.ontology_class import OntologyClass
from app.models.ontology_scene import OntologyScene
logger = get_db_logger()
class OntologyClassRepository:
"""本体类型Repository
提供本体类型的CRUD操作和权限检查。
Attributes:
db: SQLAlchemy数据库会话
"""
def __init__(self, db: Session):
"""初始化Repository
Args:
db: SQLAlchemy数据库会话
"""
self.db = db
def create(self, class_data: dict, scene_id: UUID) -> OntologyClass:
"""创建本体类型
Args:
class_data: 类型数据字典包含class_name和class_description
scene_id: 所属场景ID
Returns:
OntologyClass: 创建的类型对象
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.create(
... {"class_name": "患者", "class_description": "描述"},
... scene_id
... )
"""
try:
logger.info(
f"Creating ontology class - "
f"name={class_data.get('class_name')}, "
f"scene_id={scene_id}"
)
ontology_class = OntologyClass(
class_name=class_data.get("class_name"),
class_description=class_data.get("class_description"),
scene_id=scene_id
)
self.db.add(ontology_class)
self.db.flush() # 获取ID但不提交
logger.info(
f"Ontology class created successfully - "
f"class_id={ontology_class.class_id}"
)
return ontology_class
except Exception as e:
logger.error(
f"Failed to create ontology class: {str(e)}",
exc_info=True
)
raise
def get_by_id(self, class_id: UUID) -> Optional[OntologyClass]:
"""根据ID获取类型
Args:
class_id: 类型ID
Returns:
Optional[OntologyClass]: 类型对象不存在则返回None
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.get_by_id(class_id)
"""
try:
logger.debug(f"Getting ontology class by ID: {class_id}")
ontology_class = self.db.query(OntologyClass).filter(
OntologyClass.class_id == class_id
).first()
if ontology_class:
logger.debug(f"Ontology class found: {class_id}")
else:
logger.debug(f"Ontology class not found: {class_id}")
return ontology_class
except Exception as e:
logger.error(
f"Failed to get ontology class by ID: {str(e)}",
exc_info=True
)
raise
def get_by_name(self, class_name: str, scene_id: UUID) -> Optional[OntologyClass]:
"""根据类型名称和场景ID获取类型精确匹配
Args:
class_name: 类型名称
scene_id: 场景ID
Returns:
Optional[OntologyClass]: 类型对象不存在则返回None
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.get_by_name("患者", scene_id)
"""
try:
logger.debug(f"Getting ontology class by name: {class_name}, scene_id: {scene_id}")
ontology_class = self.db.query(OntologyClass).filter(
OntologyClass.class_name == class_name,
OntologyClass.scene_id == scene_id
).first()
if ontology_class:
logger.debug(f"Ontology class found: {class_name}")
else:
logger.debug(f"Ontology class not found: {class_name}")
return ontology_class
except Exception as e:
logger.error(
f"Failed to get ontology class by name: {str(e)}",
exc_info=True
)
raise
def search_by_name(self, keyword: str, scene_id: UUID) -> List[OntologyClass]:
"""根据关键词模糊搜索类型
使用 LIKE 进行模糊匹配,支持中文和英文。
Args:
keyword: 搜索关键词
scene_id: 场景ID
Returns:
List[OntologyClass]: 匹配的类型列表
Examples:
>>> repo = OntologyClassRepository(db)
>>> classes = repo.search_by_name("患者", scene_id)
"""
try:
logger.debug(
f"Searching ontology classes by keyword - "
f"keyword={keyword}, scene_id={scene_id}"
)
# 使用 ilike 进行不区分大小写的模糊匹配
classes = self.db.query(OntologyClass).filter(
OntologyClass.class_name.ilike(f"%{keyword}%"),
OntologyClass.scene_id == scene_id
).order_by(
OntologyClass.created_at.desc()
).all()
logger.info(
f"Found {len(classes)} ontology classes matching keyword '{keyword}' "
f"in scene {scene_id}"
)
return classes
except Exception as e:
logger.error(
f"Failed to search ontology classes by keyword: {str(e)}",
exc_info=True
)
raise
def get_by_scene(self, scene_id: UUID) -> List[OntologyClass]:
"""获取场景下的所有类型
按创建时间倒序排列。
Args:
scene_id: 场景ID
Returns:
List[OntologyClass]: 类型列表
Examples:
>>> repo = OntologyClassRepository(db)
>>> classes = repo.get_by_scene(scene_id)
"""
try:
logger.debug(f"Getting ontology classes by scene: {scene_id}")
classes = self.db.query(OntologyClass).filter(
OntologyClass.scene_id == scene_id
).order_by(
OntologyClass.created_at.desc()
).all()
logger.info(
f"Found {len(classes)} ontology classes in scene {scene_id}"
)
return classes
except Exception as e:
logger.error(
f"Failed to get ontology classes by scene: {str(e)}",
exc_info=True
)
raise
def update(self, class_id: UUID, update_data: dict) -> Optional[OntologyClass]:
"""更新类型信息
Args:
class_id: 类型ID
update_data: 更新数据字典
Returns:
Optional[OntologyClass]: 更新后的类型对象不存在则返回None
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.update(
... class_id,
... {"class_name": "新名称"}
... )
"""
try:
logger.info(f"Updating ontology class: {class_id}")
ontology_class = self.get_by_id(class_id)
if not ontology_class:
logger.warning(f"Ontology class not found for update: {class_id}")
return None
# 更新字段
if "class_name" in update_data and update_data["class_name"] is not None:
ontology_class.class_name = update_data["class_name"]
if "class_description" in update_data:
ontology_class.class_description = update_data["class_description"]
self.db.flush()
logger.info(f"Ontology class updated successfully: {class_id}")
return ontology_class
except Exception as e:
logger.error(
f"Failed to update ontology class: {str(e)}",
exc_info=True
)
raise
def delete(self, class_id: UUID) -> bool:
"""删除类型
Args:
class_id: 类型ID
Returns:
bool: 删除成功返回True类型不存在返回False
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologyClassRepository(db)
>>> success = repo.delete(class_id)
"""
try:
logger.info(f"Deleting ontology class: {class_id}")
ontology_class = self.get_by_id(class_id)
if not ontology_class:
logger.warning(f"Ontology class not found for delete: {class_id}")
return False
self.db.delete(ontology_class)
self.db.flush()
logger.info(f"Ontology class deleted successfully: {class_id}")
return True
except Exception as e:
logger.error(
f"Failed to delete ontology class: {str(e)}",
exc_info=True
)
raise
def check_ownership(self, class_id: UUID, workspace_id: UUID) -> bool:
"""检查类型是否属于指定工作空间(通过场景关联)
Args:
class_id: 类型ID
workspace_id: 工作空间ID
Returns:
bool: 属于返回True否则返回False
Examples:
>>> repo = OntologyClassRepository(db)
>>> is_owner = repo.check_ownership(class_id, workspace_id)
"""
try:
logger.debug(
f"Checking class ownership - "
f"class_id={class_id}, workspace_id={workspace_id}"
)
count = self.db.query(OntologyClass).join(
OntologyScene,
OntologyClass.scene_id == OntologyScene.scene_id
).filter(
OntologyClass.class_id == class_id,
OntologyScene.workspace_id == workspace_id
).count()
is_owner = count > 0
logger.debug(
f"Class ownership check result: {is_owner} - "
f"class_id={class_id}"
)
return is_owner
except Exception as e:
logger.error(
f"Failed to check class ownership: {str(e)}",
exc_info=True
)
raise
def get_scene_id_by_class(self, class_id: UUID) -> Optional[UUID]:
"""根据类型ID获取所属场景ID
Args:
class_id: 类型ID
Returns:
Optional[UUID]: 场景ID类型不存在则返回None
Examples:
>>> repo = OntologyClassRepository(db)
>>> scene_id = repo.get_scene_id_by_class(class_id)
"""
try:
logger.debug(f"Getting scene ID by class: {class_id}")
ontology_class = self.get_by_id(class_id)
if not ontology_class:
logger.debug(f"Class not found: {class_id}")
return None
logger.debug(
f"Found scene ID: {ontology_class.scene_id} for class: {class_id}"
)
return ontology_class.scene_id
except Exception as e:
logger.error(
f"Failed to get scene ID by class: {str(e)}",
exc_info=True
)
raise

View File

@@ -0,0 +1,439 @@
# -*- coding: utf-8 -*-
"""本体场景Repository层
本模块提供本体场景的数据访问层实现。
Classes:
OntologySceneRepository: 本体场景数据访问类
"""
import logging
from typing import List, Optional
from uuid import UUID
from sqlalchemy.orm import Session, joinedload
from app.core.logging_config import get_db_logger
from app.models.ontology_scene import OntologyScene
logger = get_db_logger()
class OntologySceneRepository:
"""本体场景Repository
提供本体场景的CRUD操作和权限检查。
Attributes:
db: SQLAlchemy数据库会话
"""
def __init__(self, db: Session):
"""初始化Repository
Args:
db: SQLAlchemy数据库会话
"""
self.db = db
def create(self, scene_data: dict, workspace_id: UUID) -> OntologyScene:
"""创建本体场景
Args:
scene_data: 场景数据字典包含scene_name和scene_description
workspace_id: 所属工作空间ID
Returns:
OntologyScene: 创建的场景对象
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.create(
... {"scene_name": "医疗场景", "scene_description": "描述"},
... workspace_id
... )
"""
try:
logger.info(
f"Creating ontology scene - "
f"name={scene_data.get('scene_name')}, "
f"workspace_id={workspace_id}"
)
scene = OntologyScene(
scene_name=scene_data.get("scene_name"),
scene_description=scene_data.get("scene_description"),
workspace_id=workspace_id
)
self.db.add(scene)
self.db.flush() # 获取ID但不提交
logger.info(
f"Ontology scene created successfully - "
f"scene_id={scene.scene_id}"
)
return scene
except Exception as e:
logger.error(
f"Failed to create ontology scene: {str(e)}",
exc_info=True
)
raise
def get_by_id(self, scene_id: UUID) -> Optional[OntologyScene]:
"""根据ID获取场景
Args:
scene_id: 场景ID
Returns:
Optional[OntologyScene]: 场景对象不存在则返回None
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.get_by_id(scene_id)
"""
try:
logger.debug(f"Getting ontology scene by ID: {scene_id}")
scene = self.db.query(OntologyScene).filter(
OntologyScene.scene_id == scene_id
).first()
if scene:
logger.debug(f"Ontology scene found: {scene_id}")
else:
logger.debug(f"Ontology scene not found: {scene_id}")
return scene
except Exception as e:
logger.error(
f"Failed to get ontology scene by ID: {str(e)}",
exc_info=True
)
raise
def get_by_name(self, scene_name: str, workspace_id: UUID) -> Optional[OntologyScene]:
"""根据场景名称和工作空间ID获取场景精确匹配
Args:
scene_name: 场景名称
workspace_id: 工作空间ID
Returns:
Optional[OntologyScene]: 场景对象不存在则返回None
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.get_by_name("医疗场景", workspace_id)
"""
try:
logger.debug(
f"Getting ontology scene by name - "
f"scene_name={scene_name}, workspace_id={workspace_id}"
)
scene = self.db.query(OntologyScene).options(
joinedload(OntologyScene.classes)
).filter(
OntologyScene.scene_name == scene_name,
OntologyScene.workspace_id == workspace_id
).first()
if scene:
logger.debug(f"Ontology scene found: {scene_name}")
else:
logger.debug(f"Ontology scene not found: {scene_name}")
return scene
except Exception as e:
logger.error(
f"Failed to get ontology scene by name: {str(e)}",
exc_info=True
)
raise
def search_by_name(self, keyword: str, workspace_id: UUID) -> List[OntologyScene]:
"""根据关键词模糊搜索场景
使用 LIKE 进行模糊匹配,支持中文和英文。
Args:
keyword: 搜索关键词
workspace_id: 工作空间ID
Returns:
List[OntologyScene]: 匹配的场景列表
Examples:
>>> repo = OntologySceneRepository(db)
>>> scenes = repo.search_by_name("医疗", workspace_id)
"""
try:
logger.debug(
f"Searching ontology scenes by keyword - "
f"keyword={keyword}, workspace_id={workspace_id}"
)
# 使用 ilike 进行不区分大小写的模糊匹配
scenes = self.db.query(OntologyScene).options(
joinedload(OntologyScene.classes)
).filter(
OntologyScene.scene_name.ilike(f"%{keyword}%"),
OntologyScene.workspace_id == workspace_id
).order_by(
OntologyScene.updated_at.desc()
).all()
logger.info(
f"Found {len(scenes)} ontology scenes matching keyword '{keyword}' "
f"in workspace {workspace_id}"
)
return scenes
except Exception as e:
logger.error(
f"Failed to search ontology scenes by keyword: {str(e)}",
exc_info=True
)
raise
def get_by_workspace(self, workspace_id: UUID, page: Optional[int] = None, page_size: Optional[int] = None) -> tuple:
"""获取工作空间下的所有场景(支持分页)
使用joinedload预加载classes关系以统计数量。
Args:
workspace_id: 工作空间ID
page: 页码可选从1开始
page_size: 每页数量(可选)
Returns:
tuple: (场景列表, 总数量)
Examples:
>>> repo = OntologySceneRepository(db)
>>> scenes, total = repo.get_by_workspace(workspace_id)
>>> scenes, total = repo.get_by_workspace(workspace_id, page=1, page_size=10)
"""
try:
logger.debug(f"Getting ontology scenes by workspace: {workspace_id}, page={page}, page_size={page_size}")
# 构建基础查询
query = self.db.query(OntologyScene).options(
joinedload(OntologyScene.classes)
).filter(
OntologyScene.workspace_id == workspace_id
).order_by(
OntologyScene.updated_at.desc()
)
# 获取总数
total = query.count()
# 如果提供了分页参数,应用分页
if page is not None and page_size is not None:
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
logger.debug(f"Applying pagination: offset={offset}, limit={page_size}")
scenes = query.all()
logger.info(
f"Found {len(scenes)} ontology scenes (total: {total}) in workspace {workspace_id}"
)
return scenes, total
except Exception as e:
logger.error(
f"Failed to get ontology scenes by workspace: {str(e)}",
exc_info=True
)
raise
def update(self, scene_id: UUID, update_data: dict) -> Optional[OntologyScene]:
"""更新场景信息
Args:
scene_id: 场景ID
update_data: 更新数据字典
Returns:
Optional[OntologyScene]: 更新后的场景对象不存在则返回None
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.update(
... scene_id,
... {"scene_name": "新名称"}
... )
"""
try:
logger.info(f"Updating ontology scene: {scene_id}")
scene = self.get_by_id(scene_id)
if not scene:
logger.warning(f"Ontology scene not found for update: {scene_id}")
return None
# 更新字段
if "scene_name" in update_data and update_data["scene_name"] is not None:
scene.scene_name = update_data["scene_name"]
if "scene_description" in update_data:
scene.scene_description = update_data["scene_description"]
self.db.flush()
logger.info(f"Ontology scene updated successfully: {scene_id}")
return scene
except Exception as e:
logger.error(
f"Failed to update ontology scene: {str(e)}",
exc_info=True
)
raise
def delete(self, scene_id: UUID) -> bool:
"""删除场景(级联删除类型)
依赖数据库级联删除配置ondelete="CASCADE")。
Args:
scene_id: 场景ID
Returns:
bool: 删除成功返回True场景不存在返回False
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologySceneRepository(db)
>>> success = repo.delete(scene_id)
"""
try:
logger.info(f"Deleting ontology scene: {scene_id}")
scene = self.get_by_id(scene_id)
if not scene:
logger.warning(f"Ontology scene not found for delete: {scene_id}")
return False
self.db.delete(scene)
self.db.flush()
logger.info(
f"Ontology scene deleted successfully (cascade): {scene_id}"
)
return True
except Exception as e:
logger.error(
f"Failed to delete ontology scene: {str(e)}",
exc_info=True
)
raise
def check_ownership(self, scene_id: UUID, workspace_id: UUID) -> bool:
"""检查场景是否属于指定工作空间
Args:
scene_id: 场景ID
workspace_id: 工作空间ID
Returns:
bool: 属于返回True否则返回False
Examples:
>>> repo = OntologySceneRepository(db)
>>> is_owner = repo.check_ownership(scene_id, workspace_id)
"""
try:
logger.debug(
f"Checking scene ownership - "
f"scene_id={scene_id}, workspace_id={workspace_id}"
)
count = self.db.query(OntologyScene).filter(
OntologyScene.scene_id == scene_id,
OntologyScene.workspace_id == workspace_id
).count()
is_owner = count > 0
logger.debug(
f"Scene ownership check result: {is_owner} - "
f"scene_id={scene_id}"
)
return is_owner
except Exception as e:
logger.error(
f"Failed to check scene ownership: {str(e)}",
exc_info=True
)
raise
def get_simple_list(self, workspace_id: UUID) -> List[dict]:
"""获取场景简单列表仅包含scene_id和scene_name用于下拉选择
这是一个轻量级查询不加载关联的classes响应速度快。
Args:
workspace_id: 工作空间ID
Returns:
List[dict]: 场景简单列表每项包含scene_id和scene_name
Examples:
>>> repo = OntologySceneRepository(db)
>>> scenes = repo.get_simple_list(workspace_id)
>>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...]
"""
try:
logger.debug(f"Getting simple scene list for workspace: {workspace_id}")
# 只查询需要的字段,不加载关联数据
results = self.db.query(
OntologyScene.scene_id,
OntologyScene.scene_name
).filter(
OntologyScene.workspace_id == workspace_id
).order_by(
OntologyScene.updated_at.desc()
).all()
scenes = [
{"scene_id": str(r.scene_id), "scene_name": r.scene_name}
for r in results
]
logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
return scenes
except Exception as e:
logger.error(
f"Failed to get simple scene list: {str(e)}",
exc_info=True
)
raise

View File

@@ -4,7 +4,10 @@ from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.prompt_optimizer_model import (
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
PromptOptimizerSession,
PromptOptimizerSessionHistory,
RoleType,
PromptHistory
)
db_logger = get_db_logger()
@@ -16,6 +19,12 @@ class PromptOptimizerSessionRepository:
def __init__(self, db: Session):
self.db = db
def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None:
session = self.db.query(PromptOptimizerSession).filter(
PromptOptimizerSession.id == session_id,
).first()
return session
def create_session(
self,
tenant_id: uuid.UUID,
@@ -38,12 +47,9 @@ class PromptOptimizerSessionRepository:
user_id=user_id,
)
self.db.add(session)
self.db.commit()
self.db.refresh(session)
db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
return session
except Exception as e:
db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
db_logger.error(f"Error creating prompt optimization session: - {str(e)}")
raise
def get_session_history(
@@ -71,10 +77,10 @@ class PromptOptimizerSessionRepository:
PromptOptimizerSession.id == session_id,
PromptOptimizerSession.user_id == user_id
).first()
if not session:
return []
history = self.db.query(PromptOptimizerSessionHistory).filter(
PromptOptimizerSessionHistory.session_id == session.id,
PromptOptimizerSessionHistory.user_id == user_id
@@ -104,11 +110,11 @@ class PromptOptimizerSessionRepository:
PromptOptimizerSession.user_id == user_id,
PromptOptimizerSession.tenant_id == tenant_id
).first()
if not session:
db_logger.error(f"Session {session_id} not found for user {user_id}")
raise ValueError(f"Session {session_id} not found for user {user_id}")
message = PromptOptimizerSessionHistory(
tenant_id=tenant_id,
session_id=session.id,
@@ -117,8 +123,199 @@ class PromptOptimizerSessionRepository:
content=content,
)
self.db.add(message)
self.db.commit()
return message
except Exception as e:
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
raise
def get_first_user_message(self, session_id: uuid.UUID) -> str | None:
"""
Get the first user message from a session.
Args:
session_id (uuid.UUID): The session ID.
Returns:
str | None: The content of the first user message, or None if not found.
"""
try:
message = self.db.query(PromptOptimizerSessionHistory).filter(
PromptOptimizerSessionHistory.session_id == session_id,
PromptOptimizerSessionHistory.role == RoleType.USER.value
).order_by(
PromptOptimizerSessionHistory.created_at.asc()
).first()
return message.content if message else None
except Exception as e:
db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}")
raise
class PromptReleaseRepository:
def __init__(self, db: Session):
self.db = db
def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None:
prompt_obj = self.db.query(PromptHistory).filter(
PromptHistory.session_id == session_id,
PromptHistory.is_delete.is_(False)
).first()
return prompt_obj
def create_prompt_release(
self,
tenant_id: uuid.UUID,
title: str,
session_id: uuid.UUID,
prompt: str,
) -> PromptHistory:
try:
prompt_obj = PromptHistory(
tenant_id=tenant_id,
title=title,
session_id=session_id,
prompt=prompt,
)
self.db.add(prompt_obj)
return prompt_obj
except Exception as e:
db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}")
raise
def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None:
"""
Soft delete a prompt release by setting is_delete flag to True.
Args:
prompt_obj (PromptHistory): The prompt release object to delete.
"""
try:
prompt_obj.is_delete = True
db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}")
except Exception as e:
db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}")
raise
def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None:
"""
Get a prompt release by its ID.
Args:
prompt_id (uuid.UUID): The prompt release ID.
Returns:
PromptHistory | None: The prompt release object or None if not found.
"""
try:
prompt_obj = self.db.query(PromptHistory).filter(
PromptHistory.id == prompt_id
).first()
return prompt_obj
except Exception as e:
db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}")
raise
def count_prompts(self, tenant_id: uuid.UUID) -> int:
"""
Count total number of non-deleted prompts for a tenant.
Args:
tenant_id (uuid.UUID): The tenant ID.
Returns:
int: Total count of prompts.
"""
try:
count = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False)
).count()
return count
except Exception as e:
db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}")
raise
def get_prompts_paginated(
self,
tenant_id: uuid.UUID,
offset: int,
limit: int
) -> list[PromptHistory]:
"""
Get paginated list of prompt releases for a tenant.
Args:
tenant_id (uuid.UUID): The tenant ID.
offset (int): Number of records to skip.
limit (int): Maximum number of records to return.
Returns:
list[PromptHistory]: List of prompt releases.
"""
try:
prompts = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False)
).order_by(
PromptHistory.created_at.desc()
).offset(offset).limit(limit).all()
return prompts
except Exception as e:
db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}")
raise
def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int:
"""
Count total number of non-deleted prompts matching keyword for a tenant.
Args:
tenant_id (uuid.UUID): The tenant ID.
keyword (str): Search keyword for title.
Returns:
int: Total count of matching prompts.
"""
try:
count = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False),
PromptHistory.title.ilike(f"%{keyword}%")
).count()
return count
except Exception as e:
db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
raise
def search_prompts_paginated(
self,
tenant_id: uuid.UUID,
keyword: str,
offset: int,
limit: int
) -> list[PromptHistory]:
"""
Search prompt releases by keyword in title with pagination.
Args:
tenant_id (uuid.UUID): The tenant ID.
keyword (str): Search keyword for title.
offset (int): Number of records to skip.
limit (int): Maximum number of records to return.
Returns:
list[PromptHistory]: List of matching prompt releases.
"""
try:
prompts = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False),
PromptHistory.title.ilike(f"%{keyword}%")
).order_by(
PromptHistory.created_at.desc()
).offset(offset).limit(limit).all()
return prompts
except Exception as e:
db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
raise

View File

@@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel):
kb_id: str = Field(..., description="知识库ID")
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
# strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
# weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
retrieve_type: str = Field(default="hybrid", description="检索方式participle semantichybrid")

View File

@@ -1,3 +1,4 @@
from abc import ABC
from typing import Optional
from pydantic import BaseModel
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel):
messages: list[dict]
end_user_id: str
config_id: Optional[str] = None
config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量"""
STORAGE_NEO4J = "neo4j"
STORAGE_RAG = "rag"
STRATEGY_AGGREGATE = "aggregate"
STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6

View File

@@ -1,5 +1,7 @@
import uuid
from pydantic import BaseModel, Field
from typing import Optional
from typing import Optional, Union
from uuid import UUID
from enum import Enum
@@ -10,7 +12,7 @@ class OptimizationStrategy(str, Enum):
ACCURACY_FIRST = "accuracy_first"
BALANCED = "balanced"
class Memory_Reflection(BaseModel):
config_id: Optional[UUID] = None
config_id: Union[uuid.UUID, int, str] = None
reflection_enabled: bool
reflection_period_in_hours: str
reflexion_range: Optional[str] = "partial"

View File

@@ -147,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: uuid.UUID = Field("config_id", description="配置唯一标识UUID")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
user_id: str = Field("user_id", description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
@@ -229,26 +229,32 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
config_desc: str = Field("配置描述", description="配置描述(字符串)")
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间IDUUID")
# 本体场景关联(可选)
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景IDUUID关联ontology_scene表")
# 模型配置字段(可选,用于手动指定或自动填充)
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)")
config_id: uuid.UUID = Field("配置ID", description="配置IDUUID")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[uuid.UUID] = None
config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)")
config_id: Union[uuid.UUID, int, str] = None
config_name: Optional[str] = Field(None, description="配置名称(字符串)")
config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id: Optional[uuid.UUID] = None
config_id:Union[uuid.UUID, int, str] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
@@ -315,14 +321,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型
config_id: Optional[uuid.UUID] = None
config_id:Union[uuid.UUID, int, str] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id: uuid.UUID = Field(..., description="配置ID唯一")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -330,7 +336,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Optional[uuid.UUID] = None
config_id: Union[uuid.UUID, int, str] = None
user_id: Optional[str] = None
apply_id: Optional[str] = None
@@ -406,7 +412,7 @@ class ForgettingConfigResponse(BaseModel):
"""遗忘引擎配置响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: uuid.UUID = Field(..., description="配置ID")
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串")
decay_constant: float = Field(..., description="衰减常数 d")
lambda_time: float = Field(..., description="时间衰减参数")
lambda_mem: float = Field(..., description="记忆衰减参数")
@@ -423,8 +429,8 @@ class ForgettingConfigResponse(BaseModel):
class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: uuid.UUID = Field(..., description="配置ID")
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识UUID或int)")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -499,7 +505,7 @@ class ForgettingCurveRequest(BaseModel):
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1")
days: int = Field(60, ge=1, le=365, description="模拟天数默认60天")
config_id: Optional[uuid.UUID] = Field(None, description="配置ID可选如果为None则使用默认配置")
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
class ForgettingCurveResponse(BaseModel):

View File

@@ -3,14 +3,12 @@ from typing import Optional, List, Dict, Any
import datetime
import uuid
from app.models.models_model import ModelProvider, ModelType
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
from app.core.logging_config import get_business_logger
schema_logger = get_business_logger()
# ModelConfig Schemas
class ModelConfigBase(BaseModel):
"""模型配置基础Schema"""
@@ -22,6 +20,7 @@ class ModelConfigBase(BaseModel):
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
is_active: bool = Field(True, description="是否激活")
is_public: bool = Field(False, description="是否公开")
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
class ApiKeyCreateNested(BaseModel):
@@ -44,13 +43,14 @@ class ModelConfigCreate(ModelConfigBase):
class CompositeModelCreate(BaseModel):
"""创建组合模型Schema"""
name: str = Field(..., description="组合模型名称", max_length=255)
type: ModelType = Field(..., description="模型类型")
type: Optional[ModelType] = Field(None, description="模型类型")
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
description: Optional[str] = Field(None, description="模型描述")
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
is_active: bool = Field(True, description="是否激活")
is_public: bool = Field(False, description="是否公开")
api_key_ids: List[uuid.UUID] = Field(..., description="绑定的API Key ID列表")
load_balance_strategy: Optional[str] = Field(default=LoadBalanceStrategy.NONE.value, description="负载均衡策略")
class ModelConfigUpdate(BaseModel):

View File

@@ -0,0 +1,461 @@
"""本体提取API的请求和响应模型
本模块定义了本体提取系统的所有API请求和响应的Pydantic模型。
Classes:
ExtractionRequest: 本体提取请求模型
ExtractionResponse: 本体提取响应模型
ExportRequest: OWL文件导出请求模型
ExportResponse: OWL文件导出响应模型
OntologyResultResponse: 本体提取结果响应模型(带毫秒时间戳)
SceneCreateRequest: 场景创建请求模型
SceneUpdateRequest: 场景更新请求模型
SceneResponse: 场景响应模型
SceneListResponse: 场景列表响应模型
ClassCreateRequest: 类型创建请求模型
ClassUpdateRequest: 类型更新请求模型
ClassResponse: 类型响应模型
ClassListResponse: 类型列表响应模型
"""
from typing import List, Optional
import datetime
from uuid import UUID
from pydantic import BaseModel, Field, field_serializer, ConfigDict
from app.core.memory.models.ontology_models import OntologyClass
class ExtractionRequest(BaseModel):
"""本体提取请求模型
用于POST /api/ontology/extract端点的请求体。
Attributes:
scenario: 场景描述文本,不能为空
domain: 可选的领域提示(如Healthcare, Education等)
llm_id: LLM模型ID,必须提供
scene_id: 场景ID,必须提供,用于将提取的类保存到指定场景
Examples:
>>> request = ExtractionRequest(
... scenario="医院管理患者记录...",
... domain="Healthcare",
... llm_id="550e8400-e29b-41d4-a716-446655440000",
... scene_id="660e8400-e29b-41d4-a716-446655440000"
... )
"""
scenario: str = Field(..., description="场景描述文本", min_length=1)
domain: Optional[str] = Field(None, description="可选的领域提示")
llm_id: str = Field(..., description="LLM模型ID")
scene_id: UUID = Field(..., description="场景ID,用于将提取的类保存到指定场景")
class ExtractionResponse(BaseModel):
"""本体提取响应模型
用于POST /api/ontology/extract端点的响应体。
Attributes:
classes: 提取的本体类列表
domain: 识别的领域
extracted_count: 提取的类数量
Examples:
>>> response = ExtractionResponse(
... classes=[...],
... domain="Healthcare",
... extracted_count=7
... )
"""
classes: List[OntologyClass] = Field(default_factory=list, description="提取的本体类列表")
domain: str = Field(..., description="识别的领域")
extracted_count: int = Field(..., description="提取的类数量")
class ExportRequest(BaseModel):
"""OWL文件导出请求模型
用于POST /api/ontology/export端点的请求体。
Attributes:
classes: 要导出的本体类列表
format: 导出格式,可选值: rdfxml, turtle, ntriples, json
include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True
Examples:
>>> request = ExportRequest(
... classes=[...],
... format="rdfxml",
... include_metadata=True
... )
"""
classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1)
format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json")
include_metadata: bool = Field(True, description="是否包含完整的OWL元数据")
class ExportResponse(BaseModel):
"""OWL文件导出响应模型
用于POST /api/ontology/export端点的响应体。
Attributes:
owl_content: OWL文件内容
format: 导出格式
classes_count: 导出的类数量
Examples:
>>> response = ExportResponse(
... owl_content="<?xml version='1.0'?>...",
... format="rdfxml",
... classes_count=7
... )
"""
owl_content: str = Field(..., description="OWL文件内容")
format: str = Field(..., description="导出格式")
classes_count: int = Field(..., description="导出的类数量")
class OntologyResultResponse(BaseModel):
"""本体提取结果响应模型
用于返回数据库中存储的提取结果,时间戳为毫秒级。
Attributes:
id: 结果ID (UUID)
scenario: 场景描述文本
domain: 领域
classes_json: 提取的本体类数据(JSON格式)
extracted_count: 提取的类数量
user_id: 用户ID
created_at: 创建时间(毫秒时间戳)
Examples:
>>> response = OntologyResultResponse(
... id=uuid.uuid4(),
... scenario="医院管理患者记录...",
... domain="Healthcare",
... classes_json={"classes": [...]},
... extracted_count=7,
... user_id=123,
... created_at=datetime.now()
... )
"""
id: UUID = Field(..., description="结果ID")
scenario: str = Field(..., description="场景描述文本")
domain: Optional[str] = Field(None, description="领域")
classes_json: dict = Field(..., description="提取的本体类数据(JSON格式)")
extracted_count: int = Field(..., description="提取的类数量")
user_id: Optional[int] = Field(None, description="用户ID")
created_at: datetime.datetime = Field(..., description="创建时间")
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
"""将创建时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
class Config:
from_attributes = True
# ==================== 本体场景相关 Schema ====================
class SceneCreateRequest(BaseModel):
"""场景创建请求模型
用于创建新的本体场景。
Attributes:
scene_name: 场景名称必填1-200字符
scene_description: 场景描述,可选
Examples:
>>> request = SceneCreateRequest(
... scene_name="医疗场景",
... scene_description="用于医疗领域的本体建模"
... )
"""
scene_name: str = Field(..., min_length=1, max_length=200, description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
class SceneUpdateRequest(BaseModel):
"""场景更新请求模型
用于更新已有本体场景信息。
Attributes:
scene_name: 场景名称可选1-200字符
scene_description: 场景描述,可选
Examples:
>>> request = SceneUpdateRequest(
... scene_name="更新后的场景名称",
... scene_description="更新后的描述"
... )
"""
scene_name: Optional[str] = Field(None, min_length=1, max_length=200, description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
class SceneResponse(BaseModel):
"""场景响应模型
用于返回本体场景信息。
Attributes:
scene_id: 场景ID
scene_name: 场景名称
scene_description: 场景描述
type_num: 类型数量
workspace_id: 所属工作空间ID
created_at: 创建时间(毫秒时间戳)
updated_at: 更新时间(毫秒时间戳)
classes_count: 类型数量
Examples:
>>> response = SceneResponse(
... scene_id=uuid.uuid4(),
... scene_name="医疗场景",
... scene_description="用于医疗领域的本体建模",
... type_num=0,
... workspace_id=uuid.uuid4(),
... created_at=datetime.now(),
... updated_at=datetime.now(),
... classes_count=5
... )
"""
scene_id: UUID = Field(..., description="场景ID")
scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
type_num: int = Field(..., description="类型数量")
entity_type: Optional[List[str]] = Field(None, description="实体类型列表最多3个class_name")
workspace_id: UUID = Field(..., description="所属工作空间ID")
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
classes_count: int = Field(0, description="类型数量")
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
"""将创建时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
"""将更新时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
model_config = ConfigDict(from_attributes=True)
class PaginationInfo(BaseModel):
"""分页信息模型
Attributes:
page: 当前页码
pagesize: 每页数量
total: 总数量
hasnext: 是否有下一页
"""
page: int = Field(..., description="当前页码")
pagesize: int = Field(..., description="每页数量")
total: int = Field(..., description="总数量")
hasnext: bool = Field(..., description="是否有下一页")
class SceneListResponse(BaseModel):
"""场景列表响应模型(支持分页)
用于返回本体场景列表。
Attributes:
items: 场景列表
page: 分页信息(可选,分页时返回)
Examples:
>>> # 不分页
>>> response = SceneListResponse(
... items=[scene1, scene2]
... )
>>> # 分页
>>> response = SceneListResponse(
... items=[scene1, scene2, ...],
... page=PaginationInfo(page=1, pagesize=100, total=150, hasnext=True)
... )
"""
items: List[SceneResponse] = Field(..., description="场景列表")
page: Optional[PaginationInfo] = Field(None, description="分页信息")
# ==================== 本体类型相关 Schema ====================
class ClassItem(BaseModel):
"""单个类型信息模型
Attributes:
class_name: 类型名称必填1-200字符
class_description: 类型描述,可选
Examples:
>>> item = ClassItem(
... class_name="患者",
... class_description="医院患者信息"
... )
"""
class_name: str = Field(..., min_length=1, max_length=200, description="类型名称")
class_description: Optional[str] = Field(None, description="类型描述")
class ClassCreateRequest(BaseModel):
"""类型创建请求模型(统一使用列表形式)
通过列表中元素数量决定创建模式:
- 列表包含 1 个元素:单个创建
- 列表包含多个元素:批量创建
Attributes:
scene_id: 所属场景ID必填
classes: 类型列表,必填,至少包含 1 个元素
Examples:
# 单个创建(列表中 1 个元素)
>>> request = ClassCreateRequest(
... scene_id=uuid.uuid4(),
... classes=[
... ClassItem(class_name="患者", class_description="医院患者信息")
... ]
... )
# 批量创建(列表中多个元素)
>>> request = ClassCreateRequest(
... scene_id=uuid.uuid4(),
... classes=[
... ClassItem(class_name="患者", class_description="医院患者信息"),
... ClassItem(class_name="医生", class_description="医院医生信息"),
... ClassItem(class_name="药品", class_description="医院药品信息")
... ]
... )
"""
scene_id: UUID = Field(..., description="所属场景ID")
classes: List[ClassItem] = Field(..., min_length=1, description="类型列表,至少包含 1 个元素")
class ClassUpdateRequest(BaseModel):
"""类型更新请求模型
用于更新已有本体类型信息。
Attributes:
class_name: 类型名称可选1-200字符
class_description: 类型描述,可选
Examples:
>>> request = ClassUpdateRequest(
... class_name="更新后的类型名称",
... class_description="更新后的描述"
... )
"""
class_name: Optional[str] = Field(None, min_length=1, max_length=200, description="类型名称")
class_description: Optional[str] = Field(None, description="类型描述")
class ClassResponse(BaseModel):
"""类型响应模型
用于返回本体类型信息。
Attributes:
class_id: 类型ID
class_name: 类型名称
class_description: 类型描述
scene_id: 所属场景ID
created_at: 创建时间(毫秒时间戳)
updated_at: 更新时间(毫秒时间戳)
Examples:
>>> response = ClassResponse(
... class_id=uuid.uuid4(),
... class_name="患者",
... class_description="医院患者信息",
... scene_id=uuid.uuid4(),
... created_at=datetime.now(),
... updated_at=datetime.now()
... )
"""
class_id: UUID = Field(..., description="类型ID")
class_name: str = Field(..., description="类型名称")
class_description: Optional[str] = Field(None, description="类型描述")
scene_id: UUID = Field(..., description="所属场景ID")
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
"""将创建时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
"""将更新时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
model_config = ConfigDict(from_attributes=True)
class ClassBatchCreateResponse(BaseModel):
"""批量创建类型响应模型
用于返回批量创建的结果统计和详情。
Attributes:
total: 总共尝试创建的数量
success_count: 成功创建的数量
failed_count: 失败的数量
items: 成功创建的类型列表
errors: 失败的错误信息列表(可选)
Examples:
>>> response = ClassBatchCreateResponse(
... total=3,
... success_count=2,
... failed_count=1,
... items=[class1, class2],
... errors=["创建类型 '药品' 失败: 类型名称已存在"]
... )
"""
total: int = Field(..., description="总共尝试创建的数量")
success_count: int = Field(..., description="成功创建的数量")
failed_count: int = Field(0, description="失败的数量")
items: List[ClassResponse] = Field(..., description="成功创建的类型列表")
errors: Optional[List[str]] = Field(None, description="失败的错误信息列表")
class ClassListResponse(BaseModel):
"""类型列表响应模型
用于返回本体类型列表。
Attributes:
total: 总数量
scene_id: 所属场景ID
scene_name: 场景名称
scene_description: 场景描述
items: 类型列表
Examples:
>>> response = ClassListResponse(
... total=3,
... scene_id=uuid.uuid4(),
... scene_name="医疗场景",
... scene_description="用于医疗领域的本体建模",
... items=[class1, class2, class3]
... )
"""
total: int = Field(..., description="总数量")
scene_id: UUID = Field(..., description="所属场景ID")
scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
items: List[ClassResponse] = Field(..., description="类型列表")

View File

@@ -22,6 +22,23 @@ class PromptOptMessage(BaseModel):
)
class PromptSaveRequest(BaseModel):
session_id: UUID = Field(
...,
description="Session ID"
)
title: str = Field(
...,
description="Prompt Title"
)
prompt: str = Field(
...,
description="Optimized prompt content"
)
class PromptOptModelSet(BaseModel):
id: UUID | None = Field(
default=None,

View File

@@ -171,7 +171,14 @@ class AppChatService:
self.conversation_service.save_conversation_messages(
conversation_id=conversation_id,
user_message=message,
assistant_message=result["content"]
assistant_message=result["content"],
meta_data={
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
elapsed_time = time.time() - start_time
@@ -310,6 +317,7 @@ class AppChatService:
# 流式调用 Agent
full_content = ""
total_tokens = 0
async for chunk in agent.chat_stream(
message=message,
history=history,
@@ -320,9 +328,12 @@ class AppChatService:
config_id=config_id,
memory_flag=memory_flag
):
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
if isinstance(chunk, int):
total_tokens = chunk
else:
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
elapsed_time = time.time() - start_time
@@ -339,7 +350,7 @@ class AppChatService:
content=full_content,
meta_data={
"model": api_key_obj.model_name,
"usage": {}
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
}
)
@@ -416,7 +427,11 @@ class AppChatService:
meta_data={
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"sub_results": result.get("sub_results")
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
@@ -458,6 +473,7 @@ class AppChatService:
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
full_content = ""
total_tokens = 0
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
@@ -474,16 +490,26 @@ class AppChatService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
if "sub_usage" in event:
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "total_tokens" in data:
total_tokens += data["total_tokens"]
except:
pass
else:
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
elapsed_time = time.time() - start_time
@@ -499,7 +525,12 @@ class AppChatService:
role="assistant",
content=full_content,
meta_data={
"elapsed_time": elapsed_time
"elapsed_time": elapsed_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
)

View File

@@ -187,7 +187,7 @@ class AppStatisticsService:
daily_tokens[date_str] = 0
daily_tokens[date_str] += int(tokens)
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["tokens"] for row in daily_data)
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["count"] for row in daily_data)
return {"daily": daily_data, "total": total}

View File

@@ -1,4 +1,5 @@
"""会话服务"""
import os
import uuid
from datetime import datetime, timedelta
from typing import Annotated
@@ -298,7 +299,8 @@ class ConversationService:
self,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str
assistant_message: str,
meta_data: Optional[dict] = None
):
"""
Save a pair of user and assistant messages to the conversation.
@@ -307,6 +309,7 @@ class ConversationService:
conversation_id (uuid.UUID): Conversation UUID.
user_message (str): User's message content.
assistant_message (str): Assistant's response content.
meta_data (Optional[dict]): Optional metadata for the messages.
"""
self.add_message(
conversation_id=conversation_id,
@@ -317,7 +320,8 @@ class ConversationService:
self.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message
content=assistant_message,
meta_data=meta_data
)
logger.debug(
@@ -526,12 +530,12 @@ class ConversationService:
takeaways=[],
info_score=0,
)
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
system_prompt = f.read()
rendered_system_message = Template(system_prompt).render()
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f:
with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
user_prompt = f.read()
rendered_user_message = Template(user_prompt).render(
language=language,

View File

@@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally:
db.close()
@@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"content_length": len(str(memory_content))
}
)
return f"检索到以下历史记忆:\n\n{memory_content}"
except Exception as e:
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
@@ -442,7 +443,14 @@ class DraftRunService:
user_message=message,
assistant_message=result["content"],
app_id=agent_config.app_id,
user_id=user_id
user_id=user_id,
meta_data={
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
response = {
@@ -649,6 +657,7 @@ class DraftRunService:
# 9. 流式调用 Agent
full_content = ""
total_tokens = 0
async for chunk in agent.chat_stream(
message=message,
history=history,
@@ -659,14 +668,22 @@ class DraftRunService:
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag
):
full_content += chunk
# 发送消息块事件
yield self._format_sse_event("message", {
"content": chunk
})
if isinstance(chunk, int):
total_tokens = chunk
else:
full_content += chunk
# 发送消息块事件
yield self._format_sse_event("message", {
"content": chunk
})
elapsed_time = time.time() - start_time
if sub_agent:
yield self._format_sse_event("sub_usage", {
"total_tokens": total_tokens
})
# 10. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
await self._save_conversation_message(
@@ -674,7 +691,10 @@ class DraftRunService:
user_message=message,
assistant_message=full_content,
app_id=agent_config.app_id,
user_id=user_id
user_id=user_id,
meta_data={
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
}
)
# 11. 发送结束事件
@@ -898,6 +918,7 @@ class DraftRunService:
conversation_id: str,
user_message: str,
assistant_message: str,
meta_data: dict,
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None
) -> None:
@@ -909,6 +930,7 @@ class DraftRunService:
assistant_message: AI 回复消息
app_id: 应用ID未使用保留用于兼容性
user_id: 用户ID未使用保留用于兼容性
meta_data: token消耗
"""
try:
from app.services.conversation_service import ConversationService
@@ -927,7 +949,8 @@ class DraftRunService:
conversation_service.add_message(
conversation_id=conv_uuid,
role="assistant",
content=assistant_message
content=assistant_message,
meta_data=meta_data
)
logger.debug(

View File

@@ -17,12 +17,15 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
logger = get_business_logger()
class EmotionSuggestion(BaseModel):
"""情绪建议模型"""
type: str = Field(..., description="建议类型emotion_balance/activity_recommendation/social_connection/stress_management")
type: str = Field(...,
description="建议类型emotion_balance/activity_recommendation/social_connection/stress_management")
title: str = Field(..., description="建议标题")
content: str = Field(..., description="建议内容")
priority: str = Field(..., description="优先级high/medium/low")
@@ -37,33 +40,33 @@ class EmotionSuggestionsResponse(BaseModel):
class EmotionAnalyticsService:
"""情绪分析服务
提供情绪数据的分析和统计功能,包括:
- 情绪标签统计
- 情绪词云数据
- 情绪健康指数计算
- 个性化情绪建议生成
Attributes:
emotion_repo: 情绪数据仓储实例
"""
def __init__(self):
"""初始化情绪分析服务"""
connector = Neo4jConnector()
self.emotion_repo = EmotionRepository(connector)
logger.info("情绪分析服务初始化完成")
async def get_emotion_tags(
self,
end_user_id: str,
emotion_type: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
limit: int = 10
self,
end_user_id: str,
emotion_type: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
limit: int = 10
) -> Dict[str, Any]:
"""获取情绪标签统计
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
确保返回所有6个情绪维度joy、sadness、anger、fear、surprise、neutral
即使某些维度没有数据也会返回count=0的记录。
@@ -71,8 +74,8 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
f"start={start_date}, end={end_date}, limit={limit}")
f"start={start_date}, end={end_date}, limit={limit}")
# 调用仓储层查询
tags = await self.emotion_repo.get_emotion_tags(
end_user_id=end_user_id,
@@ -81,13 +84,13 @@ class EmotionAnalyticsService:
end_date=end_date,
limit=limit
)
# 定义所有6个情绪维度
all_emotion_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
# 将查询结果转换为字典,方便查找
tags_dict = {tag["emotion_type"]: tag for tag in tags}
# 补全缺失的情绪维度
complete_tags = []
for emotion in all_emotion_types:
@@ -101,52 +104,52 @@ class EmotionAnalyticsService:
"percentage": 0.0,
"avg_intensity": 0.0
})
# 计算总数
total_count = sum(tag["count"] for tag in complete_tags)
# 如果有数据重新计算百分比因为补全了0值项
if total_count > 0:
for tag in complete_tags:
if tag["count"] > 0:
tag["percentage"] = round((tag["count"] / total_count) * 100, 2)
# 构建时间范围信息
time_range = {}
if start_date:
time_range["start_date"] = start_date
if end_date:
time_range["end_date"] = end_date
# 格式化响应
response = {
"tags": complete_tags,
"total_count": total_count,
"time_range": time_range if time_range else None
}
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(complete_tags)}")
return response
except Exception as e:
logger.error(f"获取情绪标签统计失败: {str(e)}", exc_info=True)
raise
async def get_emotion_wordcloud(
self,
end_user_id: str,
emotion_type: Optional[str] = None,
limit: int = 50
self,
end_user_id: str,
emotion_type: Optional[str] = None,
limit: int = 50
) -> Dict[str, Any]:
"""获取情绪词云数据
查询情绪关键词及其频率,用于生成词云可视化。
Args:
end_user_id: 宿主ID用户组ID
emotion_type: 可选的情绪类型过滤
limit: 返回关键词的最大数量
Returns:
Dict: 包含情绪词云数据的响应:
- keywords: 关键词列表
@@ -154,39 +157,39 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"获取情绪词云数据: user={end_user_id}, type={emotion_type}, limit={limit}")
# 调用仓储层查询
keywords = await self.emotion_repo.get_emotion_wordcloud(
end_user_id=end_user_id,
emotion_type=emotion_type,
limit=limit
)
# 计算总关键词数量
total_keywords = len(keywords)
# 格式化响应
response = {
"keywords": keywords,
"total_keywords": total_keywords
}
logger.info(f"情绪词云数据获取完成: total_keywords={total_keywords}")
return response
except Exception as e:
logger.error(f"获取情绪词云数据失败: {str(e)}", exc_info=True)
raise
def _calculate_positivity_rate(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""计算积极率
根据情绪类型分类正面、负面和中性情绪,计算积极率。
公式:(正面数 / (正面数 + 负面数)) * 100
Args:
emotions: 情绪数据列表,每个包含 emotion_type 字段
Returns:
Dict: 包含积极率计算结果:
- score: 积极率分数0-100
@@ -197,38 +200,38 @@ class EmotionAnalyticsService:
# 定义情绪分类
positive_emotions = {'joy', 'surprise'}
negative_emotions = {'sadness', 'anger', 'fear'}
# 统计各类情绪数量
positive_count = sum(1 for e in emotions if e.get('emotion_type') in positive_emotions)
negative_count = sum(1 for e in emotions if e.get('emotion_type') in negative_emotions)
neutral_count = sum(1 for e in emotions if e.get('emotion_type') == 'neutral')
# 计算积极率
total_non_neutral = positive_count + negative_count
if total_non_neutral > 0:
score = (positive_count / total_non_neutral) * 100
else:
score = 50.0 # 如果没有非中性情绪默认为50
logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, "
f"neutral={neutral_count}, score={score:.2f}")
f"neutral={neutral_count}, score={score:.2f}")
return {
"score": round(score, 2),
"positive_count": positive_count,
"negative_count": negative_count,
"neutral_count": neutral_count
}
def _calculate_stability(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""计算稳定性
基于情绪强度的标准差计算情绪稳定性。
公式:(1 - min(std_deviation, 1.0)) * 100
Args:
emotions: 情绪数据列表,每个包含 emotion_intensity 字段
Returns:
Dict: 包含稳定性计算结果:
- score: 稳定性分数0-100
@@ -236,7 +239,7 @@ class EmotionAnalyticsService:
"""
# 提取所有情绪强度
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
# 计算标准差
if len(intensities) >= 2:
std_deviation = statistics.stdev(intensities)
@@ -244,29 +247,29 @@ class EmotionAnalyticsService:
std_deviation = 0.0 # 只有一个数据点标准差为0
else:
std_deviation = 0.0 # 没有数据标准差为0
# 计算稳定性分数
# 标准差越小,稳定性越高
score = (1 - min(std_deviation, 1.0)) * 100
logger.debug(f"稳定性计算: intensities_count={len(intensities)}, "
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
return {
"score": round(score, 2),
"std_deviation": round(std_deviation, 3)
}
def _calculate_resilience(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""计算恢复力
分析情绪转换模式,统计从负面情绪恢复到正面情绪的能力。
公式:(负面到正面转换次数 / 总负面情绪数) * 100
Args:
emotions: 情绪数据列表,每个包含 emotion_type 和 created_at 字段
应该按时间顺序排列
Returns:
Dict: 包含恢复力计算结果:
- score: 恢复力分数0-100
@@ -275,24 +278,24 @@ class EmotionAnalyticsService:
# 定义情绪分类
positive_emotions = {'joy', 'surprise'}
negative_emotions = {'sadness', 'anger', 'fear'}
# 统计负面到正面的转换次数
recovery_count = 0
negative_count = 0
for i in range(len(emotions)):
current_emotion = emotions[i].get('emotion_type')
# 统计负面情绪总数
if current_emotion in negative_emotions:
negative_count += 1
# 检查下一个情绪是否为正面
if i + 1 < len(emotions):
next_emotion = emotions[i + 1].get('emotion_type')
if next_emotion in positive_emotions:
recovery_count += 1
# 计算恢复力分数
if negative_count > 0:
recovery_rate = recovery_count / negative_count
@@ -301,28 +304,28 @@ class EmotionAnalyticsService:
# 如果没有负面情绪恢复力设为100最佳状态
recovery_rate = 1.0
score = 100.0
logger.debug(f"恢复力计算: negative_count={negative_count}, "
f"recovery_count={recovery_count}, score={score:.2f}")
f"recovery_count={recovery_count}, score={score:.2f}")
return {
"score": round(score, 2),
"recovery_rate": round(recovery_rate, 3)
}
async def calculate_emotion_health_index(
self,
end_user_id: str,
time_range: str = "30d"
self,
end_user_id: str,
time_range: str = "30d"
) -> Dict[str, Any]:
"""计算情绪健康指数
综合积极率、稳定性和恢复力计算情绪健康指数。
Args:
end_user_id: 宿主ID用户组ID
time_range: 时间范围7d/30d/90d
Returns:
Dict: 包含情绪健康指数的完整响应:
- health_score: 综合健康分数0-100
@@ -336,13 +339,13 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"计算情绪健康指数: user={end_user_id}, time_range={time_range}")
# 获取时间范围内的情绪数据
emotions = await self.emotion_repo.get_emotions_in_range(
end_user_id=end_user_id,
time_range=time_range
)
# 如果没有数据,返回默认值
if not emotions:
logger.warning(f"用户 {end_user_id} 在时间范围 {time_range} 内没有情绪数据")
@@ -357,20 +360,20 @@ class EmotionAnalyticsService:
"emotion_distribution": {},
"time_range": time_range
}
# 计算各维度指标
positivity_rate = self._calculate_positivity_rate(emotions)
stability = self._calculate_stability(emotions)
resilience = self._calculate_resilience(emotions)
# 计算综合健康分数
# 公式positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3
health_score = (
positivity_rate["score"] * 0.4 +
stability["score"] * 0.3 +
resilience["score"] * 0.3
positivity_rate["score"] * 0.4 +
stability["score"] * 0.3 +
resilience["score"] * 0.3
)
# 确定健康等级
if health_score >= 80:
level = "优秀"
@@ -380,13 +383,13 @@ class EmotionAnalyticsService:
level = "一般"
else:
level = "较差"
# 统计情绪分布
emotion_distribution = {}
for emotion_type in ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']:
count = sum(1 for e in emotions if e.get('emotion_type') == emotion_type)
emotion_distribution[emotion_type] = count
# 格式化响应
response = {
"health_score": round(health_score, 2),
@@ -399,22 +402,22 @@ class EmotionAnalyticsService:
"emotion_distribution": emotion_distribution,
"time_range": time_range
}
logger.info(f"情绪健康指数计算完成: score={health_score:.2f}, level={level}")
return response
except Exception as e:
logger.error(f"计算情绪健康指数失败: {str(e)}", exc_info=True)
raise
def _analyze_emotion_patterns(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""分析情绪模式
识别主要负面情绪、情绪触发因素和波动时段。
Args:
emotions: 情绪数据列表,每个包含 emotion_type、emotion_intensity、created_at 字段
Returns:
Dict: 包含情绪模式分析结果:
- dominant_negative_emotion: 主要负面情绪类型
@@ -422,19 +425,19 @@ class EmotionAnalyticsService:
- emotion_volatility: 情绪波动性(高/中/低)
"""
negative_emotions = {'sadness', 'anger', 'fear'}
# 统计负面情绪分布
negative_emotion_counts = {}
for emotion in emotions:
emotion_type = emotion.get('emotion_type')
if emotion_type in negative_emotions:
negative_emotion_counts[emotion_type] = negative_emotion_counts.get(emotion_type, 0) + 1
# 识别主要负面情绪
dominant_negative_emotion = None
if negative_emotion_counts:
dominant_negative_emotion = max(negative_emotion_counts, key=negative_emotion_counts.get)
# 识别高强度情绪(强度 >= 0.7
high_intensity_emotions = [
{
@@ -445,7 +448,7 @@ class EmotionAnalyticsService:
for e in emotions
if e.get('emotion_intensity', 0) >= 0.7
]
# 评估情绪波动性
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
if len(intensities) >= 2:
@@ -458,29 +461,29 @@ class EmotionAnalyticsService:
volatility = ""
else:
volatility = "未知"
logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, "
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
return {
"dominant_negative_emotion": dominant_negative_emotion,
"high_intensity_emotions": high_intensity_emotions[:5], # 最多返回5个
"emotion_volatility": volatility
}
async def generate_emotion_suggestions(
self,
end_user_id: str,
db: Session,
self,
end_user_id: str,
db: Session,
) -> Dict[str, Any]:
"""生成个性化情绪建议
基于情绪健康数据和用户画像生成个性化建议。
Args:
end_user_id: 宿主ID用户组ID
db: 数据库会话
Returns:
Dict: 包含个性化建议的响应:
- health_summary: 健康状态摘要
@@ -488,17 +491,17 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"生成个性化情绪建议: user={end_user_id}")
# 1. 从 end_user_id 获取关联的 memory_config_id
llm_client = None
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
@@ -513,35 +516,35 @@ class EmotionAnalyticsService:
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
except Exception as e:
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
# 2. 获取情绪健康数据
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
# 3. 获取情绪数据用于模式分析
emotions = await self.emotion_repo.get_emotions_in_range(
end_user_id=end_user_id,
time_range="30d"
)
# 4. 分析情绪模式
patterns = self._analyze_emotion_patterns(emotions)
# 5. 获取用户画像数据简化版直接从Neo4j获取
user_profile = await self._get_simple_user_profile(end_user_id)
# 6. 构建LLM prompt
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
# 7. 调用LLM生成建议使用配置中的LLM
if llm_client is None:
# 无法获取配置时,抛出错误而不是使用默认配置
raise ValueError("无法获取LLM配置请确保end_user关联了有效的memory_config")
# 将 prompt 转换为 messages 格式
messages = [
{"role": "user", "content": prompt}
]
# 8. 使用结构化输出直接获取 Pydantic 模型
try:
suggestions_response = await llm_client.response_structured(
@@ -552,7 +555,7 @@ class EmotionAnalyticsService:
logger.error(f"LLM 结构化输出失败: {str(e)}")
# 返回默认建议
suggestions_response = self._get_default_suggestions(health_data)
# 8. 验证建议数量3-5条
if len(suggestions_response.suggestions) < 3:
logger.warning(f"建议数量不足: {len(suggestions_response.suggestions)}")
@@ -560,7 +563,7 @@ class EmotionAnalyticsService:
elif len(suggestions_response.suggestions) > 5:
logger.warning(f"建议数量过多: {len(suggestions_response.suggestions)}")
suggestions_response.suggestions = suggestions_response.suggestions[:5]
# 9. 格式化响应
response = {
"health_summary": suggestions_response.health_summary,
@@ -575,26 +578,26 @@ class EmotionAnalyticsService:
for s in suggestions_response.suggestions
]
}
logger.info(f"个性化建议生成完成: suggestions_count={len(response['suggestions'])}")
return response
except Exception as e:
logger.error(f"生成个性化建议失败: {str(e)}", exc_info=True)
raise
async def _get_simple_user_profile(self, end_user_id: str) -> Dict[str, Any]:
"""获取简化的用户画像数据
Args:
end_user_id: 用户ID
Returns:
Dict: 用户画像数据
"""
try:
connector = Neo4jConnector()
# 查询用户的实体和标签
query = """
MATCH (e:Entity)
@@ -603,59 +606,59 @@ class EmotionAnalyticsService:
ORDER BY e.created_at DESC
LIMIT 20
"""
entities = await connector.execute_query(query, end_user_id=end_user_id)
# 提取兴趣标签
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
# 后期会引入用户的习惯。。
return {
"interests": interests if interests else ["未知"]
}
except Exception as e:
logger.error(f"获取用户画像失败: {str(e)}")
return {"interests": ["未知"]}
async def _build_suggestion_prompt(
self,
health_data: Dict[str, Any],
patterns: Dict[str, Any],
user_profile: Dict[str, Any]
self,
health_data: Dict[str, Any],
patterns: Dict[str, Any],
user_profile: Dict[str, Any]
) -> str:
"""构建情绪建议生成的prompt
Args:
health_data: 情绪健康数据
patterns: 情绪模式分析结果
user_profile: 用户画像数据
Returns:
str: LLM prompt
"""
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_suggestions_prompt,
)
prompt = await render_emotion_suggestions_prompt(
health_data=health_data,
patterns=patterns,
user_profile=user_profile
)
return prompt
def _get_default_suggestions(self, health_data: Dict[str, Any]) -> EmotionSuggestionsResponse:
"""获取默认建议当LLM调用失败时使用
Args:
health_data: 情绪健康数据
Returns:
EmotionSuggestionsResponse: 默认建议
"""
health_score = health_data.get('health_score', 0)
if health_score >= 80:
summary = "您的情绪健康状况优秀,请继续保持积极的生活态度。"
elif health_score >= 60:
@@ -664,7 +667,7 @@ class EmotionAnalyticsService:
summary = "您的情绪健康需要关注,建议采取一些改善措施。"
else:
summary = "您的情绪健康需要重点关注,建议寻求专业帮助。"
suggestions = [
EmotionSuggestion(
type="emotion_balance",
@@ -700,54 +703,54 @@ class EmotionAnalyticsService:
]
)
]
return EmotionSuggestionsResponse(
health_summary=summary,
suggestions=suggestions
)
async def get_cached_suggestions(
self,
end_user_id: str,
db: Session,
self,
end_user_id: str,
db: Session,
) -> Optional[Dict[str, Any]]:
"""从 Redis 缓存获取个性化情绪建议
Args:
end_user_id: 宿主ID用户组ID
db: 数据库会话(保留参数以保持接口兼容性)
Returns:
Dict: 缓存的建议数据,如果不存在或已过期返回 None
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
# 从 Redis 获取缓存
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
if cached_data is None:
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
return None
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
return cached_data
except Exception as e:
logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True)
return None
async def save_suggestions_cache(
self,
end_user_id: str,
suggestions_data: Dict[str, Any],
db: Session,
expires_hours: int = 24
self,
end_user_id: str,
suggestions_data: Dict[str, Any],
db: Session,
expires_hours: int = 24
) -> None:
"""保存建议到 Redis 缓存
Args:
end_user_id: 宿主ID用户组ID
suggestions_data: 建议数据
@@ -756,24 +759,24 @@ class EmotionAnalyticsService:
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
# 计算过期时间(秒)
expire_seconds = expires_hours * 3600
# 保存到 Redis
success = await EmotionMemoryCache.set_emotion_suggestions(
user_id=end_user_id,
suggestions_data=suggestions_data,
expire=expire_seconds
)
if success:
logger.info(f"建议缓存保存成功: user={end_user_id}")
else:
logger.warning(f"建议缓存保存失败: user={end_user_id}")
except Exception as e:
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
# 不抛出异常,缓存失败不应影响主流程

View File

@@ -4,7 +4,7 @@ import uuid
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from langgraph.checkpoint.memory import MemorySaver
@@ -727,9 +727,12 @@ class HandoffsService:
# 提取响应
response_content = ""
total_tokens = 0
for msg in result.get("messages", []):
if isinstance(msg, AIMessage):
response_content = msg.content
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
return {
@@ -737,7 +740,12 @@ class HandoffsService:
"active_agent": result.get("active_agent"),
"response": response_content,
"message_count": len(result.get("messages", [])),
"handoff_count": result.get("handoff_count", 0)
"handoff_count": result.get("handoff_count", 0),
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
async def chat_stream(
@@ -830,6 +838,12 @@ class HandoffsService:
# 捕获 LLM 结束事件,输出收集到的工具调用
elif kind == "on_chat_model_end":
output_message = event.get("data", {}).get("output", {})
if isinstance(output_message, AIMessageChunk):
response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
0) if response_meta else 0
yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n"
if collected_tool_calls:
# 找到参数最完整的 transfer 工具调用
best_tc = None

View File

@@ -334,7 +334,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-')
print(langchain_messages)
print(100*'-')
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,

View File

@@ -53,7 +53,10 @@ def get_workspace_end_users(
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
返回结果按 updated_at 从新到旧排序NULL 值排在最后)
"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
@@ -68,9 +71,14 @@ def get_workspace_end_users(
app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
# 按 updated_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids)
).order_by(
nullslast(desc(EndUserModel.updated_at)),
desc(EndUserModel.id)
).all()
# 转换为 Pydantic 模型(只在需要时转换)

View File

@@ -18,6 +18,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.models.app_model import App
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.utils.config_utils import resolve_config_id
api_logger = get_api_logger()
@@ -88,51 +89,44 @@ class WorkspaceAppService:
for release in app_releases:
memory_content = self._extract_memory_content(release.config)
if memory_content and memory_content in processed_configs:
continue
release_info = {
"app_id": str(release.app_id),
"config": memory_content
}
if memory_content:
processed_configs.add(memory_content)
memory_config_info = self._get_memory_config(memory_content)
if memory_config_info:
if not any(dc["config_id"] == memory_config_info["config_id"] for dc in app_info["memory_configs"]):
app_info["memory_configs"].append(memory_config_info)
app_info["releases"].append(release_info)
def _extract_memory_content(self, config: Any) -> str:
"""Extract memory_comtent from config"""
if not config or not isinstance(config, dict):
return None
memory_obj = config.get('memory')
if memory_obj and isinstance(memory_obj, dict):
return memory_obj.get('memory_content')
return None
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
"""Retrieve memory_config information based on memory_content"""
try:
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
memory_content = resolve_config_id(memory_content, self.db)
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, (memory_content))
# memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content)
# memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone()
# if memory_config_result is None:
# return None
if memory_config_result:
return {
"config_id": memory_config_result.config_id,
"config_id": memory_content,
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
"iteration_period": memory_config_result.iteration_period,
"reflexion_range": memory_config_result.reflexion_range,
@@ -144,20 +138,22 @@ class WorkspaceAppService:
}
except Exception as e:
api_logger.warning(f"查询memory_config失败memory_content: {memory_content}, 错误: {str(e)}")
return None
def _process_end_users(self, app: App, app_info: Dict[str, Any]) -> None:
"""Processing end-user information for applications"""
end_users = self.db.query(EndUser).filter(EndUser.app_id == app.id).all()
for end_user in end_users:
end_user_info = {
"id": str(end_user.id),
"app_id": str(end_user.app_id)
}
app_info["end_users"].append(end_user_info)
print(100*'-')
print(app_info)
def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]:
"""
Read the reflection time of end users
@@ -176,7 +172,7 @@ class WorkspaceAppService:
except Exception as e:
api_logger.error(f"读取用户反思时间失败end_user_id: {end_user_id}, 错误: {str(e)}")
return None
def update_end_user_reflection_time(self, end_user_id: str) -> bool:
"""
Update the reflection time of end users to the current time
@@ -189,7 +185,7 @@ class WorkspaceAppService:
"""
try:
from datetime import datetime
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first()
if end_user:
end_user.reflection_time = datetime.now()
@@ -207,7 +203,7 @@ class WorkspaceAppService:
class MemoryReflectionService:
"""Memory reflection service category"""
def __init__(self,db: Session = Depends(get_db)):
self.db=db
@@ -252,22 +248,22 @@ class MemoryReflectionService:
"end_user_id": end_user_id,
"config_data": config_data
}
async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]:
"""
Starting Reflection from Configuration Data
Args:
config_data: Configure data dictionary, including reflective configuration information
end_user_id: end_user_id
Returns:
Reflect on the execution results
"""
try:
config_id = config_data.get("config_id")
api_logger.info(f"从配置数据启动反思config_id: {config_id}, end_user_id: {end_user_id}")
if not config_data.get("enable_self_reflexion", False):
return {
@@ -277,7 +273,7 @@ class MemoryReflectionService:
"end_user_id": end_user_id,
"config_data": config_data
}
config_data_id=config_data['config_id']
reflection_config=WorkspaceAppService(self.db)._get_memory_config(config_data_id)
@@ -290,7 +286,7 @@ class MemoryReflectionService:
# 检查是否需要执行反思
should_execute = False
hours_diff = 0
if current_reflection_time is None:
# 首次执行反思
should_execute = True
@@ -302,11 +298,11 @@ class MemoryReflectionService:
reflection_time = datetime.fromisoformat(current_reflection_time)
else:
reflection_time = current_reflection_time
current_time = datetime.now()
time_diff = current_time - reflection_time
hours_diff = int(time_diff.total_seconds() / 3600)
# 检查是否达到反思周期
if hours_diff >= iteration_period:
should_execute = True
@@ -316,7 +312,7 @@ class MemoryReflectionService:
except (ValueError, TypeError) as e:
api_logger.warning(f"解析反思时间失败: {e},将执行反思")
should_execute = True
if should_execute:
api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时")
# 3. 执行反思引擎
@@ -349,7 +345,7 @@ class MemoryReflectionService:
"next_reflection_in_hours": iteration_period - hours_diff
}
except Exception as e:
config_id = config_data.get("config_id", "unknown")
api_logger.error(f"启动反思失败config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}")
@@ -360,7 +356,7 @@ class MemoryReflectionService:
"end_user_id": end_user_id,
"config_data": config_data
}
def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig:
"""Create reflective configuration objects from configuration data"""
@@ -368,12 +364,12 @@ class MemoryReflectionService:
if reflexion_range_value is None or reflexion_range_value == "":
reflexion_range_value = "partial"
reflexion_range = ReflectionRange(reflexion_range_value)
baseline_value = config_data.get("baseline")
if baseline_value is None or baseline_value == "":
baseline_value = "TIME"
baseline = ReflectionBaseline(baseline_value)
# iteration_period =
iteration_period = config_data.get("iteration_period", 24)
if isinstance(iteration_period, str):
@@ -381,7 +377,6 @@ class MemoryReflectionService:
iteration_period = int(iteration_period)
except (ValueError, TypeError):
iteration_period = 24 # 默认24小时
return ReflectionConfig(
enabled=config_data.get("enable_self_reflexion", False),
iteration_period=str(iteration_period), # ReflectionConfig期望字符串

View File

@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
if not params.reflection_model_id:
params.reflection_model_id = params.llm_id
if not params.emotion_model_id:
params.emotion_model_id = params.llm_id
config = MemoryConfigRepository.create(self.db, params)
self.db.commit()
return {"affected": 1, "config_id": config.config_id}
@@ -177,11 +183,11 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 将 ORM 对象转换为字典列表
data_list = []
for config in configs:
for config, scene_name in results:
# 安全地转换 user_id 为 int
config_id_old = None
if config.config_id_old:
@@ -203,6 +209,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"end_user_id": config.end_user_id,
"config_id_old": config_id_old,
"apply_id": config.apply_id,
"scene_id": str(config.scene_id) if config.scene_id else None,
"scene_name": scene_name, # 新增:场景名称
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,
@@ -628,10 +636,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
if m < 1:
latest_relative = "刚刚"
elif m < 60:
latest_relative = f"{m}分钟"
latest_relative = "一会"
else:
h = int(m // 60)
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
latest_relative = "较早前"
except Exception:
pass

View File

@@ -347,7 +347,9 @@ class ModelConfigService:
"is_public": model_data.is_public,
"is_composite": True
}
if "load_balance_strategy" in model_data.model_fields_set:
model_config_data["load_balance_strategy"] = model_data.load_balance_strategy
model = ModelConfigRepository.create(db, model_config_data)
db.flush()
@@ -380,7 +382,7 @@ class ModelConfigService:
for model_config in api_key.model_configs:
compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type
request_type = model_data.type
request_type = existing_model.type
if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)):
@@ -391,12 +393,14 @@ class ModelConfigService:
# 更新基本信息
existing_model.name = model_data.name
existing_model.type = model_data.type
# existing_model.type = model_data.type
existing_model.logo = model_data.logo
existing_model.description = model_data.description
existing_model.config = model_data.config
existing_model.is_active = model_data.is_active
existing_model.is_public = model_data.is_public
if "load_balance_strategy" in model_data.model_fields_set:
existing_model.load_balance_strategy = model_data.load_balance_strategy
# 更新 API Keys 关联
existing_model.api_keys.clear()
@@ -453,9 +457,11 @@ class ModelApiKeyService:
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
@staticmethod
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> List[ModelApiKey]:
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> tuple[
list[Any], list[Any]]:
"""根据provider为多个ModelConfig创建API Key"""
created_keys = []
failed_models = [] # 记录验证失败的模型
for model_config_id in data.model_config_ids:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
@@ -501,10 +507,9 @@ class ModelApiKeyService:
test_message="Hello"
)
if not validation_result["valid"]:
raise BusinessException(
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
# 记录验证失败的模型,但不抛出异常
failed_models.append(model_name)
continue
# 创建API Key
api_key_data = ModelApiKeyCreate(
@@ -526,7 +531,7 @@ class ModelApiKeyService:
for key in created_keys:
db.refresh(key)
return created_keys
return created_keys, failed_models
@staticmethod
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
@@ -684,6 +689,9 @@ class ModelBaseService:
@staticmethod
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
existing = ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider)
if existing:
raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME)
model_base = ModelBaseRepository.create(db, data.model_dump())
db.commit()
db.refresh(model_base)

View File

@@ -280,14 +280,22 @@ class MultiAgentOrchestrator:
# 4. 提取子 Agent 的 conversation_id用于多轮对话
sub_conversation_id = None
total_tokens = 0
if isinstance(results, dict):
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
# 提取 token 信息
usage = results.get("usage", {}) or results.get("result", {}).get("usage", {})
total_tokens += usage.get("total_tokens", 0)
elif isinstance(results, list) and results:
for item in results:
if "result" in item:
sub_conversation_id = item["result"].get("conversation_id")
if sub_conversation_id:
break
# 累加每个子 Agent 的 token
usage = item.get("usage", {}) or item.get("result", {}).get("usage", {})
total_tokens += usage.get("total_tokens", 0)
logger.info(
"多 Agent 任务完成",
@@ -301,9 +309,15 @@ class MultiAgentOrchestrator:
return {
"message": final_result,
"conversation_id": sub_conversation_id,
"mode": OrchestrationMode.SUPERVISOR,
"elapsed_time": elapsed_time,
"strategy": routing_decision.get("collaboration_strategy", "single"),
"sub_results": results
"sub_results": results,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
except Exception as e:
@@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator:
return {
"message": result.get("response", ""),
"conversation_id": result.get("conversation_id"),
"mode": OrchestrationMode.COLLABORATION,
"elapsed_time": elapsed_time,
"strategy": "collaboration",
"active_agent": result.get("active_agent"),
"sub_results": result
"sub_results": result,
"usage": result.get("usage")
}
except Exception as e:

View File

@@ -1,5 +1,6 @@
"""多 Agent 配置管理服务"""
import uuid
import json
from typing import Optional, List, Tuple, Any, Annotated
from fastapi import Depends
@@ -427,6 +428,23 @@ class MultiAgentService:
memory=getattr(request, 'memory', True) # 记忆功能参数
)
await self._save_conversation_message(
conversation_id=request.conversation_id,
user_message=request.message,
assistant_message=result.get("message", ""),
app_id=app_id,
user_id=request.user_id,
meta_data={
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
return result
async def run_stream(
@@ -451,11 +469,14 @@ class MultiAgentService:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
if not config.is_active:
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND)
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
full_content = ""
total_tokens = 0
# 3. 流式执行任务
async for event in orchestrator.execute_stream(
message=request.message,
@@ -468,7 +489,88 @@ class MultiAgentService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
if "sub_usage" in event:
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "total_tokens" in data:
total_tokens += data["total_tokens"]
except:
pass
else:
yield event
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
await self._save_conversation_message(
conversation_id=request.conversation_id,
user_message=request.message,
assistant_message=full_content,
app_id=app_id,
user_id=request.user_id,
meta_data={
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
)
async def _save_conversation_message(
self,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str,
meta_data: dict,
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None
) -> None:
"""保存会话消息
Args:
conversation_id: 会话ID
user_message: 用户消息
assistant_message: AI 回复消息
meta_data: 元数据(包括 token 消耗)
app_id: 应用ID
user_id: 用户ID
"""
try:
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)
conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=user_message
)
conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message,
meta_data=meta_data
)
logger.debug(
"保存多 Agent 会话消息",
extra={
"conversation_id": conversation_id,
"user_message_length": len(user_message),
"assistant_message_length": len(assistant_message)
}
)
except Exception as e:
logger.warning("保存会话消息失败", extra={"error": str(e)})
# def add_sub_agent(
# self,

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
import os
import re
import uuid
from typing import Any, AsyncGenerator
@@ -16,9 +17,10 @@ from app.models.prompt_optimizer_model import (
PromptOptimizerSession,
RoleType
)
from app.repositories.model_repository import ModelConfigRepository
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.repositories.prompt_optimizer_repository import (
PromptOptimizerSessionRepository
PromptOptimizerSessionRepository,
PromptReleaseRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
@@ -28,6 +30,8 @@ logger = get_business_logger()
class PromptOptimizerService:
def __init__(self, db: Session):
self.db = db
self.optim_repo = PromptOptimizerSessionRepository(self.db)
self.release_repo = PromptReleaseRepository(self.db)
def get_model_config(
self,
@@ -78,10 +82,12 @@ class PromptOptimizerService:
Returns:
PromptOptimzerSession: The newly created prompt optimization session.
"""
session = PromptOptimizerSessionRepository(self.db).create_session(
session = self.optim_repo.create_session(
tenant_id=tenant_id,
user_id=user_id
)
self.db.commit()
self.db.refresh(session)
return session
def get_session_message_history(
@@ -106,7 +112,7 @@ class PromptOptimizerService:
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
- content (str): The content of the message.
"""
history = PromptOptimizerSessionRepository(self.db).get_session_history(
history = self.optim_repo.get_session_history(
session_id=session_id,
user_id=user_id
)
@@ -168,7 +174,8 @@ class PromptOptimizerService:
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
# Create LLM instance
api_config: ModelApiKey = model_config.api_keys[0]
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
api_config: ModelApiKey = api_keys[0] if api_keys else None
llm = RedBearLLM(RedBearModelConfig(
model_name=api_config.model_name,
provider=api_config.provider,
@@ -176,11 +183,12 @@ class PromptOptimizerService:
base_url=api_config.api_base
), type=ModelType(model_config.type))
try:
with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render()
with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f:
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
opt_user_prompt = f.read()
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
@@ -295,4 +303,165 @@ class PromptOptimizerService:
role=role,
content=content
)
self.db.commit()
self.db.refresh(message)
return message
def save_prompt(
self,
tenant_id: uuid.UUID,
session_id: uuid.UUID,
title: str,
prompt: str
) -> dict:
"""
Create and save a new prompt release for a given session.
Args:
tenant_id (uuid.UUID): The ID of the tenant owning the prompt.
session_id (uuid.UUID): The ID of the session to associate with this prompt.
title (str): The title of the prompt release.
prompt (str): The content of the prompt.
Returns:
dict: A dictionary containing:
- id (UUID): The unique ID of the created prompt release.
- session_id (UUID): The session ID linked to the release.
- title (str): The title of the prompt.
- prompt (str): The prompt content.
- created_at (int): Timestamp (in milliseconds) of when the prompt was created.
Raises:
BusinessException: If a prompt release already exists for the given session.
"""
session = self.optim_repo.get_session_by_id(session_id)
if session is None or session.tenant_id != tenant_id:
raise BusinessException(
"Session does not exist or the current user has no access",
BizCode.BAD_REQUEST
)
if self.release_repo.get_prompt_by_session_id(session_id):
raise BusinessException(
"A release already exists for the current session",
BizCode.BAD_REQUEST
)
prompt_obj = self.release_repo.create_prompt_release(
tenant_id=tenant_id,
title=title,
session_id=session_id,
prompt=prompt
)
self.db.commit()
self.db.refresh(prompt_obj)
return {
"id": prompt_obj.id,
"session_id": prompt_obj.session_id,
"title": prompt_obj.title,
"prompt": prompt_obj.prompt,
"created_at": int(prompt_obj.created_at.timestamp() * 1000)
}
def delete_prompt(
self,
tenant_id: uuid.UUID,
prompt_id: uuid.UUID
) -> None:
"""
Soft delete a prompt release by prompt_id.
Args:
tenant_id (uuid.UUID): Tenant identifier.
prompt_id (uuid.UUID): Prompt identifier.
Raises:
BusinessException: If the prompt does not exist or already deleted.
"""
prompt_obj = self.release_repo.get_prompt_by_id(prompt_id)
if not prompt_obj or prompt_obj.is_delete:
raise BusinessException(
"Prompt does not exist or has already been deleted",
BizCode.NOT_FOUND
)
if prompt_obj.tenant_id != tenant_id:
raise BusinessException(
"No permission to delete this prompt",
BizCode.FORBIDDEN
)
self.release_repo.soft_delete_prompt(prompt_obj)
self.db.commit()
logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}")
def get_release_list(
self,
tenant_id: uuid.UUID,
page: int,
page_size: int,
filter_keyword: str | None = None
) -> dict[str, int | list[Any]]:
"""
Get paginated list of prompt releases with optional filter.
Args:
tenant_id (uuid.UUID): Tenant identifier.
page (int): Page number (starting from 1).
page_size (int): Number of items per page.
filter_keyword (str | None): Optional keyword to filter by title.
Returns:
dict: Contains total count, pagination info, and list of releases.
"""
offset = (page - 1) * page_size
# Get total count and releases based on filter
if filter_keyword:
total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword)
releases = self.release_repo.search_prompts_paginated(
tenant_id=tenant_id,
keyword=filter_keyword,
offset=offset,
limit=page_size
)
else:
total = self.release_repo.count_prompts(tenant_id)
releases = self.release_repo.get_prompts_paginated(
tenant_id=tenant_id,
offset=offset,
limit=page_size
)
items = []
for release in releases:
# Get first user message from session
first_message = self.optim_repo.get_first_user_message(
session_id=release.session_id
)
items.append({
"id": release.id,
"title": release.title,
"prompt": release.prompt,
"created_at": int(release.created_at.timestamp() * 1000),
"first_message": first_message
})
log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}"
if filter_keyword:
log_msg += f", filter='{filter_keyword}'"
logger.info(log_msg)
result = {
"page": {
"total": total,
"page": page,
"page_size": page_size,
"hasnext": page * page_size < total
},
"keyword": filter_keyword,
"items": items
}
return result

Some files were not shown because too many files have changed in this diff Show More