Compare commits

..

385 Commits

Author SHA1 Message Date
Ke Sun
3c99fb116c Merge pull request #568 from SuanmoSuanyangTechnology/release/v0.2.7
Release/v0.2.7
2026-03-13 22:56:02 +08:00
Ke Sun
e7e136036c Merge pull request #566 from SuanmoSuanyangTechnology/fix/time
[changes] Note the time zone
2026-03-13 21:57:02 +08:00
lanceyq
ca84fc6c9d [changes] Note the time zone 2026-03-13 21:55:10 +08:00
Mark
32d612fbeb Merge pull request #565 from SuanmoSuanyangTechnology/fix/version_027
docs(version)
2026-03-13 18:53:13 +08:00
Timebomb2018
9ce3a881f3 docs(version): version description 2026-03-13 18:51:33 +08:00
yujiangping
1635f9dbef fix:cancel api 2026-03-13 18:10:06 +08:00
yujiangping
75e36173cd Merge branch 'release/v0.2.7' into feature/tool_yjp 2026-03-13 17:35:14 +08:00
yujiangping
8097f227ca feat(mcp market): Add configuration update notification and refactor MCP list fetching
- Add marketConfigUpdated i18n message in English and Chinese translations
- Replace inline MCP list fetching logic with fetchMcpList function call
- Improve code maintainability by centralizing MCP list retrieval logic
- Ensure consistent handling of MCP list state across configuration updates
2026-03-13 17:34:18 +08:00
yujiangping
fd1debe681 fix:knowbase view 2026-03-13 17:24:17 +08:00
yujiangping
39f3a85bb1 Merge branch 'feature/tool_yjp' into release/v0.2.7 2026-03-13 16:43:35 +08:00
yujiangping
098a2e54ae fix:loading 2026-03-13 16:41:55 +08:00
Mark
d575478b53 Merge pull request #557 from SuanmoSuanyangTechnology/fix/app_027
fix(mcp)
2026-03-13 16:21:40 +08:00
Timebomb2018
d4f2094ee0 fix(mcp): The token configuration modification of MCP Market needs to be verified. 2026-03-13 16:18:46 +08:00
Mark
065f8db2f7 Merge pull request #552 from SuanmoSuanyangTechnology/fix/app_027
fix(mcp)
2026-03-13 14:15:49 +08:00
yujiangping
0ac7f83726 Merge branch 'feature/tool_yjp' into release/v0.2.7 2026-03-13 14:08:21 +08:00
yujiangping
d03473da10 fix:input disable 2026-03-13 14:07:41 +08:00
Timebomb2018
dac1c01a2c fix(mcp): bug fix 2026-03-13 14:05:27 +08:00
Mark
83015a3404 Merge pull request #550 from SuanmoSuanyangTechnology/fix/app_027
fix(mcp)
2026-03-13 11:10:22 +08:00
Timebomb2018
b88e9c5f5e fix(mcp): The MCP Square can obtain a maximum of 100 MCP services. 2026-03-13 11:07:32 +08:00
yujiangping
8380a8a811 Merge branch 'release/v0.2.7' into feature/tool_yjp 2026-03-13 10:36:10 +08:00
yujiangping
6c69181290 fix:reset 2026-03-13 10:33:11 +08:00
Mark
0f36c5c872 Merge pull request #547 from SuanmoSuanyangTechnology/fix/RAG-show
[changes] Field standardization
2026-03-12 19:52:57 +08:00
lanceyq
6a67f028ce [changes] Set constants 2026-03-12 19:50:32 +08:00
Mark
5d82786c20 Merge pull request #548 from SuanmoSuanyangTechnology/fix/app_027
fix(mcp square)
2026-03-12 19:42:38 +08:00
Timebomb2018
e368f1c1d6 fix(mcp square): Do not obtain the mcp service when the token is empty. 2026-03-12 19:40:14 +08:00
lanceyq
572ce7f9ec [changes] Field standardization 2026-03-12 19:13:24 +08:00
Ke Sun
4859ab3ba7 Merge pull request #544 from SuanmoSuanyangTechnology/fix/RAG-show
[add] RAG storage displays the page effect
2026-03-12 18:40:17 +08:00
yingzhao
983b5f5087 Merge pull request #545 from SuanmoSuanyangTechnology/fix/v0.2.7_zy
feat(web): rag content  add page
2026-03-12 18:38:29 +08:00
zhaoying
75b87955dd feat(web): rag content add page 2026-03-12 18:36:49 +08:00
lanceyq
110de0afbc [add] RAG storage displays the page effect 2026-03-12 18:35:09 +08:00
Mark
2c074cd5c1 Merge pull request #543 from SuanmoSuanyangTechnology/fix/app_027
fix(app)
2026-03-12 18:20:33 +08:00
Timebomb2018
73e51a9b0b fix(app): Bug fixes for application import and export 2026-03-12 17:36:42 +08:00
yingzhao
3a47039919 Merge pull request #542 from SuanmoSuanyangTechnology/fix/v0.2.7_zy
fix(web): upload cancel add refresh
2026-03-12 17:21:54 +08:00
zhaoying
2961ea4e44 fix(web): upload cancel add refresh 2026-03-12 17:20:16 +08:00
yingzhao
af2ffc9737 Merge pull request #541 from SuanmoSuanyangTechnology/fix/v0.2.7_zy
fix(web): update i18n
2026-03-12 17:02:37 +08:00
zhaoying
d7911244fc fix(web): update i18n 2026-03-12 17:00:30 +08:00
Ke Sun
6029a5a9a8 Merge pull request #539 from SuanmoSuanyangTechnology/fix/RAG-field
[changes] Remove the non-existent "storage_type"
2026-03-12 13:57:04 +08:00
lanceyq
71d9ae15a1 [changes] Remove the non-existent "storage_type" 2026-03-12 13:53:47 +08:00
Ke Sun
4706ea59fe Merge pull request #536 from SuanmoSuanyangTechnology/fix/task
[changes] Time zone modification
2026-03-12 12:03:33 +08:00
lanceyq
5774a95f61 [changes] Time zone modification 2026-03-12 12:00:56 +08:00
yujiangping
5db2c5092e Merge branch 'release/v0.2.7' of github.com:SuanmoSuanyangTechnology/MemoryBear into release/v0.2.7 2026-03-11 18:24:55 +08:00
yujiangping
59618457df feat(web): add search functionality and empty states to MCP market
- Add search input with debouncing (500ms) to filter MCP services by keywords
- Implement server-side search via keywords parameter in getMarketMCPs API call
- Add new i18n strings for empty states: marketNoData, marketNoDataDesc, marketNoSearchResult, marketNoSearchResultDesc
- Replace client-side filtering with server-side search for better performance
- Update Empty component display to show different messages for no data vs no search results
- Remove BodyWrapper component and implement custom empty state handling
- Add searchTimerRef to manage debounce timer lifecycle
- Update loadMore callback to include searchKeyword parameter for pagination consistency
- Add allowClear prop to search input for better UX
- Remove conditional rendering of search input to keep it always visible
2026-03-11 18:24:46 +08:00
Mark
8d053c97a7 Merge pull request #534 from SuanmoSuanyangTechnology/fix/mcp_027
fix(tool)
2026-03-11 17:25:15 +08:00
Timebomb2018
a3e6f67ff7 fix(tool): The MCP tool checks for duplicate additions from the main screen and performs a test before adding. 2026-03-11 17:19:07 +08:00
yujiangping
01da2e3eee Merge branch 'feature/tool_yjp' into release/v0.2.7 2026-03-11 17:13:50 +08:00
yujiangping
168cce1678 feat(web): improve MCP market UI responsiveness and add refresh after service addition
- Change getMarketTools parameter type from Query to optional Record for flexibility
- Rename marketConfig i18n key to marketConfigBtn for clarity and consistency
- Add handleRefreshAfterAdd function to refresh MCP list after successful service addition
- Update grid layout to use auto-fill responsive columns instead of fixed 3-column layout
- Disable Add button for services already in database to prevent duplicate additions
- Connect McpServiceModal refresh callback to handleRefreshAfterAdd for cache invalidation
- Improves user experience by automatically updating market list after adding services
2026-03-11 17:11:16 +08:00
yingzhao
7240dfe793 Merge pull request #533 from SuanmoSuanyangTechnology/fix/v0.2.7_zy
feat(web): model api key add request abort
2026-03-11 15:17:50 +08:00
zhaoying
b9340ba02d feat(web): model api key add request abort 2026-03-11 15:16:02 +08:00
Ke Sun
6a1b8d3ee3 Merge pull request #532 from SuanmoSuanyangTechnology/develop
Develop
2026-03-10 20:17:02 +08:00
Ke Sun
f1207dc8b9 Merge pull request #531 from SuanmoSuanyangTechnology/feature/pruning-scene
[add] Different scenarios achieve different pruning effects.
2026-03-10 20:16:08 +08:00
lanceyq
86c51559bb [add] Remove unused protected_ids; cap delete_target by actual deletable count. 2026-03-10 19:06:50 +08:00
lanceyq
8b0f806079 [add] Different scenarios achieve different pruning effects. 2026-03-10 19:00:30 +08:00
Eternity
99e94b3567 feat(workflow,app): add MIME-based file handling and HTTP response files 2026-03-10 18:28:16 +08:00
Mark
cfd5c1bc93 Merge pull request #530 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(file)
2026-03-10 18:09:32 +08:00
Timebomb2018
45d9e45346 fix(file): S3 file storage resolves the issue of inconsistent end_point and region. 2026-03-10 18:05:34 +08:00
Ke Sun
fcb3845543 Merge pull request #528 from SuanmoSuanyangTechnology/feature/pruning-optimize
Feature/pruning optimize
2026-03-10 17:37:43 +08:00
lanceyq
97eabc0c36 [add] Remove hardcoding 2026-03-10 17:25:32 +08:00
lanceyq
5328163973 Merge branch 'feature/pruning-optimize' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/pruning-optimize 2026-03-10 17:16:25 +08:00
lanceyq
7ff9dfee8c [changes] Remove hardcoded content 2026-03-10 17:14:50 +08:00
Mark
1e1675ec12 Merge pull request #527 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(app)
2026-03-10 16:22:06 +08:00
Timebomb2018
f941541304 fix(app): Workflow import verification 2026-03-10 16:18:22 +08:00
lanceyq
3f7083c5b3 [add] Modify reserved words to avoid being affected by the threshold. 2026-03-10 16:16:05 +08:00
Mark
e81faebf69 [add] migration script 2026-03-10 14:51:48 +08:00
Ke Sun
8a4d58c520 Merge pull request #524 from SuanmoSuanyangTechnology/feature/details-memory
Feature/details memory
2026-03-10 14:42:18 +08:00
yingzhao
2ac29ee89c Merge pull request #526 from SuanmoSuanyangTechnology/feature/app_zy
feat(web): app add export & import
2026-03-10 14:24:52 +08:00
yingzhao
252cdcd6f5 Merge pull request #525 from SuanmoSuanyangTechnology/feature/memory_zy
Feature/memory zy
2026-03-10 14:24:17 +08:00
zhaoying
16e2c95965 feat(web): app add export & import 2026-03-10 14:23:05 +08:00
lanceyq
10560fb34c [changes] Clearly stipulated, the conditions for raising an error 2026-03-10 13:55:53 +08:00
lanceyq
58aa60ca0e [add] Change to "Body - json" format and pass as parameters 2026-03-10 13:39:24 +08:00
zhaoying
d24b186d3e Merge branch 'feature/memory_zy' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/memory_zy 2026-03-10 13:37:42 +08:00
zhaoying
b4e81615b1 feat(web): rag user memery add refresh 2026-03-10 13:35:52 +08:00
lanceyq
424d2033ea [add] Added an interface for refreshing RAG storage image data 2026-03-10 12:11:13 +08:00
lanceyq
fd556f9b00 [add] Generate user summaries and memory insights using Jinja2 tags 2026-03-10 11:51:17 +08:00
lanceyq
e2f5fa87b1 [add] Add cache to RAG storage 2026-03-10 11:41:09 +08:00
Mark
e4a2bd3b9b Merge pull request #522 from SuanmoSuanyangTechnology/fix/bug-patch
feat(workspace, app, agent): add duplicate name validation and restrict model/memory config on agent publish
2026-03-10 11:31:54 +08:00
Mark
e3ada17a78 Merge pull request #523 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
feat(app)
2026-03-10 11:31:14 +08:00
Timebomb2018
3e5a7adfe4 feat(app): Application (agent, workflow) import/export 2026-03-10 11:28:52 +08:00
Timebomb2018
3237f4cd6e feat(app): Application (agent, workflow) import/export 2026-03-10 11:27:28 +08:00
Timebomb2018
beea826377 feat(app): Application (agent, workflow) import/export 2026-03-10 11:17:52 +08:00
Eternity
7cdbbefc64 feat(workspace, app, agent): add duplicate name validation and restrict model/memory config on agent publish 2026-03-10 10:59:59 +08:00
yujiangping
18780622b3 Merge branch 'feature/tool_yjp' into develop 2026-03-09 19:11:13 +08:00
yujiangping
f405ac4d84 fix:next button 2026-03-09 19:10:39 +08:00
Ke Sun
9fe47e2fb2 fix(memory_agent): handle draft run without current release
- Add TODO comment to verify end_user sources (chat, draft, apikey)
- Comment out release validation check to support draft run mode
- Add TODO note explaining temporary fix for draft execution
- Handle null current_release_id in result by returning None instead of failing
- Improve import formatting for MemoryConfig model import statement
- Allow configuration retrieval when app has no published release
2026-03-09 19:07:09 +08:00
lanceyq
e4aaa18f61 [changes] User summaries stored in RAG, generation of memory insights 2026-03-09 18:50:32 +08:00
yingzhao
5c3d9717dd Merge pull request #521 from SuanmoSuanyangTechnology/feature/notes_zy
feat(web): 注释节点
2026-03-09 17:36:52 +08:00
zhaoying
ac86bbd60c feat(web): 调整便签节点位置 2026-03-09 17:35:56 +08:00
zhaoying
33d12c43b2 feat(web): 注释节点 2026-03-09 17:30:43 +08:00
Ke Sun
107c676185 Merge pull request #520 from SuanmoSuanyangTechnology/feature/interest-exists
Feature/interest exists
2026-03-09 17:10:19 +08:00
yujiangping
0f221b7ee6 fix:loading 2026-03-09 16:45:48 +08:00
yujiangping
e1939ef472 feat(web): internationalize MCP market UI strings
- Add 19 new i18n keys for market-related UI text in English and Chinese
- Replace hardcoded Chinese strings with i18n translations in Market.tsx
- Update market refresh success message to use i18n key
- Internationalize market selection, configuration, and service browsing UI
- Support multi-language display for market status tags and action buttons
2026-03-09 16:31:45 +08:00
lanceyq
5438d35f17 [add] Specify the error types and clearly define the downgrade conditions 2026-03-09 16:19:55 +08:00
yujiangping
9c26d1f4c8 Merge branch 'develop' into feature/tool_yjp 2026-03-09 16:11:37 +08:00
yujiangping
4c2b31f31f feat(web): add MCP market database tracking and refresh status messages
- Add i18n translations for refresh success and failure messages in English and Chinese
- Track MCP tools already stored in database with inDatabase flag in Market component
- Display "已入库" (In Database) tag alongside activation status for MCPs
- Import getTools API to fetch full tool list for database status comparison
- Add market metadata fields (source_channel, market_id, market_config_id, mcp_service_id) to tool items when adding from market
- Preserve market source information through McpServiceModal when saving tools
- Update ToolItem type to include market tracking fields in config_data
- Improve MCP card layout to properly display multiple status tags
2026-03-09 15:36:49 +08:00
lanceyq
4f88a13256 Merge branch 'feature/interest-exists' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/interest-exists 2026-03-09 14:59:36 +08:00
lanceyq
21ae448ed7 [add] Throw out explicit error messages; Using the CST time zone 2026-03-09 14:58:03 +08:00
lanceyq
50466124c8 [add] Verification of the existence of interest distribution 2026-03-09 14:57:22 +08:00
Ke Sun
ece88a3879 Merge pull request #518 from SuanmoSuanyangTechnology/feature/timer-shaft
Feature/timer shaft
2026-03-09 14:44:46 +08:00
Mark
cedc4a92cc Merge pull request #515 from SuanmoSuanyangTechnology/feature/workflow-notes
feat(workflow): add support for notes nodes
2026-03-09 14:14:07 +08:00
Ke Sun
c8065b0c60 feat(implicit-emotions): add Redis resilience and connection pooling
- Replace single Redis client with connection pool for better concurrency and auto-reconnection
- Add graceful degradation when Redis is unavailable (None handling in get_users_needing_refresh)
- Add RedisError exception handling with fallback to process all users on mget failures
- Add type hints (Optional[redis.StrictRedis]) to Redis client parameters
- Add health check and socket timeout configuration to connection pool
- Add logging for Redis connection failures and degradation events
- Reorganize imports alphabetically for consistency across both files
- Update get_sync_redis_client to validate connection with ping() before returning
2026-03-09 14:12:53 +08:00
lanceyq
476632294f [changes] Remove the "worker-ondemand" queue 2026-03-09 14:02:23 +08:00
lanceyq
349d46e043 [changes] Add restriction words to avoid the "implicit" and "emotional" content from being mistakenly pruned. 2026-03-09 11:26:54 +08:00
Ke Sun
00e0201bf9 Merge pull request #517 from SuanmoSuanyangTechnology/release/v0.2.6
Release/v0.2.6
2026-03-09 10:56:39 +08:00
Ke Sun
b9ebe22df1 Merge pull request #516 from SuanmoSuanyangTechnology/release/v0.2.6
Release/v0.2.6
2026-03-09 10:53:16 +08:00
Eternity
389dd8d402 feat(workflow): support resizing comment nodes, add theme and author display toggle 2026-03-09 03:21:04 +08:00
Eternity
966bd8528d feat(workflow): simplify node converter registry 2026-03-09 03:08:44 +08:00
Eternity
8f789d47a2 feat(workflow): add support for notes nodes 2026-03-09 03:00:27 +08:00
lanceyq
94a40e49a0 [add] Throw out explicit error messages; Using the CST time zone 2026-03-07 17:07:38 +08:00
lanceyq
8429279eea [add] Verification of the existence of interest distribution 2026-03-07 16:55:06 +08:00
lanceyq
cef14cda9e [add] Standardize time zones; Reuse a single Redis client; Use "mget" for batch writing requests 2026-03-07 16:36:24 +08:00
lanceyq
c14f067afb [add] The "update-implicit-emotions-storage" task uses the timeline to filter the updated data users. 2026-03-07 16:23:59 +08:00
yingzhao
6c8dca6379 Merge pull request #512 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): change mousewheel factor
2026-03-07 15:25:44 +08:00
zhaoying
819d205166 fix(web): change mousewheel factor 2026-03-07 15:23:56 +08:00
yingzhao
9e17f65eda Merge pull request #511 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): jinja2 editor bugfix
2026-03-07 14:53:26 +08:00
zhaoying
7373f68172 fix(web): jinja2 editor bugfix 2026-03-07 14:52:00 +08:00
Mark
0999bd30d7 Merge pull request #510 from SuanmoSuanyangTechnology/fix/bug-patch
fix(workflow): fix compatibility issues when importing workflows from dify
2026-03-07 14:48:26 +08:00
Eternity
f01185a7fc fix(workflow): fix compatibility issues when importing workflows from dify 2026-03-07 14:44:00 +08:00
yingzhao
7cd7303754 Merge pull request #509 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): add notes node; jinja2 editor bugfix
2026-03-07 14:42:15 +08:00
zhaoying
d19fec2155 fix(web): add notes node; jinja2 editor bugfix 2026-03-07 14:40:43 +08:00
lanceyq
2612abc9d0 [add] Create a Celery task for checking the existence of the "implicit_emotions" data 2026-03-07 13:56:15 +08:00
Mark
d080b44ac3 Merge branch 'release/v0.2.6' into develop
* release/v0.2.6:
  fix(web): ontology class default tag bugfix
  fix(version): Version 0.2.6 Release Notes
  fix(web): chat file delete bugfix
  feat: support model load balancing and add message_id to API responses
  feat: support model load balancing and add message_id to API responses
  [changes] Work space isolation
  [add] Recently, memory activities have adopted Redis caching.
  [changes] Work space isolation
  [add] Recently, memory activities have adopted Redis caching.
  fix(web): upload add loading
  [changes] The enumeration check has been changed to a string.
  [changes] The enumeration check has been changed to a string.
  feat(web): http-request add headers variable
  fix(workflow): ensure file messages are written to messages in non-stream mode
  fix(workflow): fix Dify compatibility issues
  [changes] Memory write completion active failure interest cache
  feat(workflow): support multimodal context
  [changes] AI review and correction of code
  [add] Semantic pruning is unified with the ontology engineering scenario.
  feat(chat): add message_id field to chat API response
2026-03-07 11:09:39 +08:00
Mark
df18868888 Merge pull request #507 from SuanmoSuanyangTechnology/fix/version_026
fix(version)
2026-03-07 11:08:30 +08:00
yingzhao
4438b08560 Merge pull request #508 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): ontology class default tag bugfix
2026-03-07 10:35:33 +08:00
zhaoying
1029f94669 fix(web): ontology class default tag bugfix 2026-03-07 10:33:32 +08:00
Timebomb2018
0a3acf446d fix(version): Version 0.2.6 Release Notes 2026-03-07 04:19:35 +02:00
Mark
c01ad5a19e Merge pull request #498 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(mcp)
2026-03-07 10:15:18 +08:00
Mark
5a7723553c Merge pull request #505 from SuanmoSuanyangTechnology/fix/bug-patch
feat: support model load balancing and add message_id to API responses
2026-03-07 10:11:20 +08:00
yingzhao
975844eccf Merge pull request #506 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): chat file delete bugfix
2026-03-06 19:45:37 +08:00
zhaoying
865ad31f2f fix(web): chat file delete bugfix 2026-03-06 19:44:34 +08:00
Eternity
b756f0c86c feat: support model load balancing and add message_id to API responses 2026-03-06 19:42:40 +08:00
Eternity
3e5f6176af feat: support model load balancing and add message_id to API responses 2026-03-06 19:29:31 +08:00
Mark
ab5b165dc2 Merge pull request #504 from SuanmoSuanyangTechnology/feature/activity-cache
[add] Recently, memory activities have adopted Redis caching.
2026-03-06 18:48:26 +08:00
lanceyq
f9393c2f63 Merge branch 'feature/activity-cache' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/activity-cache 2026-03-06 18:39:28 +08:00
lanceyq
aa6638424c [changes] Work space isolation 2026-03-06 18:39:21 +08:00
lanceyq
834387e254 [add] Recently, memory activities have adopted Redis caching. 2026-03-06 18:39:21 +08:00
lanceyq
9caa986c80 [changes] Work space isolation 2026-03-06 18:38:23 +08:00
yingzhao
72b84dfc8f Merge pull request #503 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): upload add loading
2026-03-06 18:32:56 +08:00
lanceyq
af10195025 [add] Recently, memory activities have adopted Redis caching. 2026-03-06 18:32:24 +08:00
zhaoying
22382423ad fix(web): upload add loading 2026-03-06 18:30:40 +08:00
Ke Sun
0f80c67cbd Merge pull request #502 from SuanmoSuanyangTechnology/fix/interest-distribution
Fix/interest distribution
2026-03-06 17:36:21 +08:00
lanceyq
aa6473c1c7 Merge branch 'fix/interest-distribution' of github.com:SuanmoSuanyangTechnology/MemoryBear into fix/interest-distribution 2026-03-06 17:35:00 +08:00
lanceyq
cde61cb6ac [changes] The enumeration check has been changed to a string. 2026-03-06 17:34:52 +08:00
lanceyq
b1368997c2 [changes] The enumeration check has been changed to a string. 2026-03-06 17:33:12 +08:00
yingzhao
ec7dc448c1 Merge pull request #501 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-03-06 17:29:09 +08:00
Ke Sun
254147265e Merge pull request #497 from SuanmoSuanyangTechnology/fix/bug-patch
feat(workflow,chat): support multimodal context and add message_id to chat API response; fix Dify compatibility issues
2026-03-06 17:28:36 +08:00
zhaoying
479bba9a4e feat(web): http-request add headers variable 2026-03-06 17:27:43 +08:00
Ke Sun
cfb39a6baa Merge pull request #500 from SuanmoSuanyangTechnology/fix/interest-distribution
[changes] Memory write completion active failure interest cache
2026-03-06 17:26:18 +08:00
Eternity
05c9ed1450 fix(workflow): ensure file messages are written to messages in non-stream mode 2026-03-06 17:26:03 +08:00
Eternity
f53633a8b8 fix(workflow): fix Dify compatibility issues 2026-03-06 17:17:29 +08:00
yingzhao
f56bc0f85a Merge pull request #499 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-03-06 17:17:08 +08:00
lanceyq
63882e9391 [changes] Memory write completion active failure interest cache 2026-03-06 17:16:00 +08:00
zhaoying
3c4dfb868f fix(web): knowledge-retrieval node's config ignore name & description key 2026-03-06 17:15:32 +08:00
Timebomb2018
9600d687fa fix(mcp): Obtain the MCP tool information to complete the channel information 2026-03-06 17:15:12 +08:00
Ke Sun
cae9105b8d Merge pull request #489 from SuanmoSuanyangTechnology/feature/scene-uniformity
[add] Semantic pruning is unified with the ontology engineering scena…
2026-03-06 16:55:20 +08:00
Ke Sun
41a0036bf6 chore(migrations): add MCP tool config source tracking fields
- Add source_channel column to mcp_tool_configs with 'self_hosted' default
- Add market_id column to track marketplace source reference
- Add market_config_id column to store marketplace configuration reference
- Add mcp_service_id column to identify MCP service instances
- Enable tracking of tool origin and marketplace integration metadata
2026-03-06 16:52:27 +08:00
yingzhao
2c9401ccfb Merge pull request #496 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): model status bugfix
2026-03-06 16:40:55 +08:00
Ke Sun
08e4ad6a7c Merge pull request #495 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(mcp)
2026-03-06 16:40:03 +08:00
zhaoying
2b0dedc81c fix(web): model status bugfix 2026-03-06 16:38:11 +08:00
Ke Sun
314e6e29d5 Merge pull request #494 from SuanmoSuanyangTechnology/release/v0.2.6
Release/v0.2.6
2026-03-06 16:37:23 +08:00
Ke Sun
16b87de0df Merge branch 'develop' into release/v0.2.6 2026-03-06 16:37:02 +08:00
Ke Sun
8c3af7f4ff fix(config): update default Redis DB numbers for Celery isolation
- Change REDIS_DB_CELERY_BROKER default from 1 to 3
- Change REDIS_DB_CELERY_BACKEND default from 2 to 4
- Add documentation comments explaining DB isolation strategy
- Prevent task interference when multiple developers share same Redis instance
2026-03-06 16:35:24 +08:00
Timebomb2018
391cd602a2 fix(mcp): MCP tool binds the information of the tool marketplace 2026-03-06 16:32:33 +08:00
yingzhao
5f56cc8056 Merge pull request #493 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): workflow upload bugfix
2026-03-06 16:18:30 +08:00
zhaoying
827ab27bef fix(web): workflow upload bugfix 2026-03-06 16:12:55 +08:00
Eternity
ccc67df8df feat(workflow): support multimodal context 2026-03-06 15:44:37 +08:00
yujiangping
82538c469f Merge branch 'fix/v0.2.6_yjp' into release/v0.2.6 2026-03-06 15:32:34 +08:00
yujiangping
076ceee29d fix(web): filter vision models for image2text and cleanup tool management
- Add vision capability filter for image2text model options in CreateModal
- Filter model options to only include models with 'vision' capability when type is 'image2text'
- Remove outdated file header comments from ToolManagement component
- Comment out 'market' tab from tabKeys array in ToolManagement
- Ensure image2text tool only displays compatible vision-capable models
2026-03-06 15:30:30 +08:00
yingzhao
822b73b015 Merge pull request #491 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): i18n update
2026-03-06 15:19:26 +08:00
zhaoying
862bff51cb fix(web): i18n update 2026-03-06 15:18:36 +08:00
yujiangping
247db844a4 fix:market 2026-03-06 15:11:50 +08:00
yujiangping
5495d32822 fix:conflict 2026-03-06 15:11:01 +08:00
yujiangping
bccbeaabe4 fix:tool market hidden 2026-03-06 15:09:05 +08:00
yujiangping
a496991400 Merge branch 'develop' into feature/tool_yjp 2026-03-06 15:03:57 +08:00
yujiangping
0ea83b4364 feat(web): enable MCP market configuration and service management
- Add market configuration API endpoints for creating, updating, and retrieving market configs
- Add market MCP listing and detail endpoints with support for activated services
- Implement MarketConfigModal component for configuring market connections with URL and API key
- Implement McpServiceModal component for viewing and managing MCP services from markets
- Add infinite scroll pagination for market sources and MCP services
- Add market connection status indicators (connected/disconnected/connecting states)
- Add i18n translations for market configuration UI (en and zh)
- Update Market component to display market sources with connection management
- Add MarketQuery type for market-specific API queries
- Refactor market data structure to match backend API response format
2026-03-06 14:55:45 +08:00
Mark
03676b7adc Merge pull request #490 from SuanmoSuanyangTechnology/fix/mutimodal
fix(agent and model)
2026-03-06 14:48:34 +08:00
Timebomb2018
af6fde414f fix(agent and model):
1. From the model square to the model list, the added models are initially set to be inactive. When manually activating them, a mandatory API key configuration is required.
2. Copying of applications (agent, workflow, multi_agent)
2026-03-06 14:40:07 +08:00
lanceyq
d069809001 [changes] AI review and correction of code 2026-03-06 14:35:16 +08:00
lanceyq
fc240849cf [add] Semantic pruning is unified with the ontology engineering scenario. 2026-03-06 14:12:03 +08:00
yingzhao
61d2a328fe Merge pull request #488 from SuanmoSuanyangTechnology/fix/release_web_zy
feat(web): change memory extraction pruning_scene control
2026-03-06 14:02:18 +08:00
zhaoying
fed0ae8e9c feat(web): change memory extraction pruning_scene control 2026-03-06 13:54:33 +08:00
yingzhao
eaf0de453b Merge pull request #487 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-03-06 13:38:56 +08:00
Eternity
e833db954a feat(chat): add message_id field to chat API response 2026-03-06 13:37:16 +08:00
zhaoying
0b2651f4ed fix(web): chat file delete bugfix 2026-03-06 13:36:50 +08:00
Ke Sun
10c677a6fd Merge pull request #486 from SuanmoSuanyangTechnology/release/v0.2.6
Release/v0.2.6
2026-03-06 12:29:07 +08:00
zhaoying
3398c4737a fix(web): Official models do not support configuration 2026-03-06 12:27:52 +08:00
Ke Sun
a008f5fbef Merge pull request #485 from SuanmoSuanyangTechnology/feature/default-ontology
[add] Default label for the entity type
2026-03-06 12:27:23 +08:00
zhaoying
6a42e73667 fix(web): Pre-generate attachment preview links 2026-03-06 12:25:09 +08:00
zhaoying
7611db19f3 fix(web): app upload jump add delay 2026-03-06 12:06:32 +08:00
lanceyq
d3399dfaf5 [add] Default label for the entity type 2026-03-06 11:49:02 +08:00
yingzhao
248f0d95ac Merge pull request #484 from SuanmoSuanyangTechnology/fix/release_web_zy
feat(web): default ontology hidden operate
2026-03-06 11:30:38 +08:00
zhaoying
5c39d841ee feat(web): default ontology hidden operate 2026-03-06 11:29:32 +08:00
yingzhao
87be67cb9a Merge pull request #482 from SuanmoSuanyangTechnology/fix/release_web_zy
Fix/release web zy
2026-03-06 10:51:04 +08:00
zhaoying
1a08bea864 fix(web): update i18n 2026-03-06 10:50:16 +08:00
zhaoying
bc4406cec6 feat(web): ontology add warning info 2026-03-06 10:49:18 +08:00
Mark
4206c849c3 Merge pull request #481 from SuanmoSuanyangTechnology/fix/mutimodal
feat(model apikey)
2026-03-06 10:46:49 +08:00
zhaoying
3f052b7798 feat(web): ontology add warning info 2026-03-06 10:45:12 +08:00
Timebomb2018
f1c5f24f6b feat(model apikey): Add validation modification for adding the apikey to the muti_modal model 2026-03-06 10:43:13 +08:00
Mark
e981c95225 Merge pull request #478 from SuanmoSuanyangTechnology/fix/db-connect-leak
fix(db): fix database connection leak
2026-03-06 10:40:35 +08:00
Ke Sun
4ce4f53835 Merge pull request #480 from SuanmoSuanyangTechnology/fix/celery-env-hijack
Fix/celery env hijack
2026-03-06 10:37:27 +08:00
Ke Sun
f16e369540 fix(celery): remove legacy environment variables to prevent CLI hijacking
- Remove BROKER_URL environment variable to prevent Celery CLI override
- Remove RESULT_BACKEND environment variable to prevent Celery CLI override
- Remove CELERY_BROKER environment variable to prevent Celery CLI override
- Remove CELERY_BACKEND environment variable to prevent Celery CLI override
- Add clarifying comments explaining the purpose of neutralizing legacy vars
- Ensures canonical broker and backend URLs are not accidentally overridden by Celery's CLI/Click integration
2026-03-06 10:37:00 +08:00
Ke Sun
47bf93d65e docs(config): update Celery environment variable naming convention
- Replace BROKER_URL and RESULT_BACKEND with REDIS_DB_CELERY_BROKER and REDIS_DB_CELERY_BACKEND in README.md
- Replace BROKER_URL and RESULT_BACKEND with REDIS_DB_CELERY_BROKER and REDIS_DB_CELERY_BACKEND in README_CN.md
- Update api/env.example with new variable names and add deprecation notice
- Add reference to celery-env-bug-report.md documentation explaining why old variable names are avoided
- Prevents environment variable hijacking by Celery CLI when using standard naming conventions
2026-03-06 10:28:03 +08:00
Ke Sun
5c2e0af33e fix(celery): resolve environment variable hijacking by Celery CLI
- Rename CELERY_BROKER and CELERY_BACKEND to REDIS_DB_CELERY_BROKER and REDIS_DB_CELERY_BACKEND to avoid Celery CLI prefix matching hijacking
- Build canonical broker and backend URLs and force them into os.environ to prevent override by stray environment variables
- Add logging for Celery app initialization with sanitized connection details
- Update celery_app.py to use pre-built URL variables instead of inline construction
- Add documentation reference to celery-env-bug-report.md explaining the environment variable naming convention
- Prevents Celery CLI's Click framework from intercepting broker/backend configuration through environment variables
2026-03-06 10:28:03 +08:00
Eternity
aaa0410781 fix(db): fix database connection leak 2026-03-06 10:21:32 +08:00
Mark
366b148f3d Merge pull request #475 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(tool and api key)
2026-03-06 10:17:10 +08:00
Ke Sun
6a265de31c Merge pull request #477 from SuanmoSuanyangTechnology/fix/ontology
[changes] From the perspective of logical judgment, to determine the …
2026-03-05 19:02:16 +08:00
lanceyq
c3707f543c [changes] From the perspective of logical judgment, to determine the situation of duplicate names 2026-03-05 18:59:23 +08:00
Ke Sun
8de368348b Merge pull request #476 from SuanmoSuanyangTechnology/fix/ontology
Fix/ontology
2026-03-05 18:38:42 +08:00
lanceyq
d052c31ac5 [changes] The pre-query at the service layer has been removed. The DB constraint ensures a unique single source of truth. 2026-03-05 18:36:12 +08:00
lanceyq
31320afed6 Merge branch 'fix/ontology' of github.com:SuanmoSuanyangTechnology/MemoryBear into fix/ontology 2026-03-05 18:19:39 +08:00
lanceyq
7afe507296 [add] Memory configuration adds uniqueness detection 2026-03-05 18:19:30 +08:00
lanceyq
4188443101 [add] Repeatability test 2026-03-05 18:19:30 +08:00
lanceyq
a1fc0fd394 [add] Added checks for idempotency of the ontology project 2026-03-05 18:19:30 +08:00
lanceyq
71fe35533d [add] Memory configuration adds uniqueness detection 2026-03-05 18:15:31 +08:00
lanceyq
a2ed335e59 [add] Repeatability test 2026-03-05 18:04:46 +08:00
lanceyq
8422a05d74 [add] Added checks for idempotency of the ontology project 2026-03-05 17:22:18 +08:00
Timebomb2018
139ae3bcb4 fix(tool and api key)
1. Tool name duplication check;
2. The default QPS value of API key is set to 100.
2026-03-05 17:08:09 +08:00
yingzhao
a0a57d5fbb Merge pull request #474 from SuanmoSuanyangTechnology/fix/release_web_zy
fix(web): adjust variable validation timing during Agent debugging
2026-03-05 17:07:13 +08:00
zhaoying
80fa88ac37 fix(web): adjust variable validation timing during Agent debugging 2026-03-05 17:05:48 +08:00
Ke Sun
0fda1c752d Merge pull request #473 from SuanmoSuanyangTechnology/fix/default
Fix/default
2026-03-05 17:05:15 +08:00
lanceyq
6c2fc75199 [fix] Memory configuration, addition of default identifiers for the ontology scene 2026-03-05 17:02:14 +08:00
lanceyq
2cb6aeb022 [fix] The interface returns "is_system_default" 2026-03-05 17:02:14 +08:00
yingzhao
e0174f75b3 Merge pull request #471 from SuanmoSuanyangTechnology/feature/memory_zy
feat(web): memory config & ontology add default tag
2026-03-05 16:50:10 +08:00
yingzhao
51d04746a3 Merge branch 'release/v0.2.6' into feature/memory_zy 2026-03-05 16:49:46 +08:00
yingzhao
3b08d6c320 Merge pull request #470 from SuanmoSuanyangTechnology/feature/form_zy
feat(web): knowledge add form rules
2026-03-05 16:45:13 +08:00
zhaoying
495c5802a0 feat(web): knowledge add form rules 2026-03-05 16:43:59 +08:00
zhaoying
621b074b3d feat(web): memory config & ontology add default tag 2026-03-05 16:36:39 +08:00
Ke Sun
6df32983b5 Merge pull request #469 from SuanmoSuanyangTechnology/fix/bug
[fix] Remove the unused ones
2026-03-05 16:23:25 +08:00
lanceyq
9c9fe9dde7 [fix] Remove the unused ones 2026-03-05 16:21:27 +08:00
Ke Sun
128c1a6178 Merge pull request #467 from SuanmoSuanyangTechnology/fix/api-service
[changes]
2026-03-05 15:20:14 +08:00
yingzhao
f90e102854 Merge pull request #468 from SuanmoSuanyangTechnology/feature/model_zy
feat(web): file type add default value
2026-03-05 15:15:56 +08:00
zhaoying
2e1eb9a5a6 feat(web): file type add default value 2026-03-05 15:12:18 +08:00
lanceyq
60a95f6556 [changes] 2026-03-05 15:02:01 +08:00
Mark
218637e81d [add] migration script 2026-03-05 14:42:42 +08:00
Mark
404f78af0f Merge tag 'v0.2.5-hotfix-1' into develop
v2.0.5-hotfix

# Conflicts:
#	api/app/cache/__init__.py
#	api/app/cache/memory/__init__.py
#	api/app/celery_app.py
#	api/app/core/config.py
#	web/src/api/memory.ts
#	web/src/views/Workflow/components/Chat/Chat.tsx
2026-03-05 14:37:35 +08:00
Mark
6301528301 Merge pull request #466 from SuanmoSuanyangTechnology/feature/agent-variables
Enhance workflow input handling and add legacy dify compatibility
2026-03-05 14:21:31 +08:00
lixiangcheng1
6feea968e0 Merge branch 'feature/knowledge_lxc' into develop 2026-03-05 14:21:13 +08:00
lixiangcheng1
b5199b2eb9 【ADD】list operational mcp servers 2026-03-05 14:18:33 +08:00
Eternity
78ce2a9a8b feat(workflow): support multimodal input 2026-03-05 14:16:30 +08:00
yingzhao
6ed542b007 Merge pull request #465 from SuanmoSuanyangTechnology/feature/model_zy
Feature/model zy
2026-03-05 12:29:45 +08:00
Ke Sun
5322b0c4a3 Merge pull request #464 from SuanmoSuanyangTechnology/fix/ontology-scene
[fix] Deleting the default scene results in a 400 status code. A unif…
2026-03-05 11:26:01 +08:00
Eternity
a72d5d2c77 fix(workflow): add backward compatibility for old dify configs 2026-03-05 11:18:48 +08:00
Eternity
16c1cbe24f feat(agent): add input variable validation 2026-03-05 11:17:56 +08:00
yingzhao
0d8f4c76e7 Merge pull request #463 from SuanmoSuanyangTechnology/feature/workflow_import_zy
feat(web): chat variable support paragraph
2026-03-05 11:07:29 +08:00
lanceyq
e511b14933 [fix] Deleting the default scene results in a 400 status code. A unified language pop-up prompt is displayed. 2026-03-05 11:06:46 +08:00
zhaoying
b5ba53208e feat(web): chat variable support paragraph 2026-03-05 11:05:51 +08:00
yingzhao
b8bfb4d0c5 Merge pull request #462 from SuanmoSuanyangTechnology/feature/memory_zy
feat(web): add SYSTEM_DEFAULT_SCENE_CANNOT_DELETE error i18n
2026-03-05 10:59:59 +08:00
zhaoying
1b666638bc feat(web): add SYSTEM_DEFAULT_SCENE_CANNOT_DELETE error i18n 2026-03-05 10:58:25 +08:00
Mark
2bd364eca3 [add] migration script 2026-03-05 10:46:31 +08:00
zhaoying
f27fc51801 Merge branch 'develop' into feature/model_zy 2026-03-05 10:32:02 +08:00
Mark
0f85eff76b Merge pull request #460 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
feat(model and app)
2026-03-05 10:31:50 +08:00
zhaoying
0def474cc2 feat(web): app's chat support audio/video/document file 2026-03-05 10:30:35 +08:00
Timebomb2018
590ec3a446 feat(model and app):
1. Increase support for visual models and multimodal models;
2. The application and workflow can input various multimodal files such as images, documents, audio, and videos.
2026-03-05 09:55:54 +08:00
Ke Sun
23bfdcefef Merge pull request #458 from SuanmoSuanyangTechnology/fix/RAG-memory
Fix/rag memory
2026-03-04 19:09:03 +08:00
lanceyq
647a978865 [fix] Restore task 2026-03-04 19:07:40 +08:00
Ke Sun
86f72100f0 Merge pull request #457 from SuanmoSuanyangTechnology/fix/External-API
Fix/external api
2026-03-04 18:24:32 +08:00
yingzhao
8b255259ba Merge pull request #459 from SuanmoSuanyangTechnology/feature/workflow_import_zy
fix(web): chat loading fix
2026-03-04 18:07:22 +08:00
zhaoying
8aad8faae9 fix(web): chat loading fix 2026-03-04 18:05:54 +08:00
lanceyq
420f391f3c [fix] Fixed tuple unpacking and moved UUID conversion into the try block. 2026-03-04 18:01:56 +08:00
lanceyq
817221347f [fix] Preserve full result dict and default status to "unknown" instead of "success". 2026-03-04 17:57:58 +08:00
lanceyq
13dce5e265 Merge branch 'fix/RAG-memory' of github.com:SuanmoSuanyangTechnology/MemoryBear into fix/RAG-memory 2026-03-04 17:48:44 +08:00
lanceyq
850d9ee70b [changes] Hide the user knowledge base and unify the display of memory capacity 2026-03-04 17:48:25 +08:00
lanceyq
ba36ccb21f [changes] Hide the user knowledge base and unify the display of memory capacity 2026-03-04 17:46:13 +08:00
lanceyq
f712754927 Merge branch 'fix/External-API' of github.com:SuanmoSuanyangTechnology/MemoryBear into fix/External-API 2026-03-04 17:28:33 +08:00
lanceyq
efe3865aa4 [fix] Fix the external write memory API 2026-03-04 17:28:24 +08:00
lanceyq
53dbe2f436 [fix] Fix the external write memory API 2026-03-04 17:26:30 +08:00
yingzhao
720498084b Merge pull request #456 from SuanmoSuanyangTechnology/feature/form_zy
Feature/form zy
2026-03-04 17:06:22 +08:00
zhaoying
f5eda38dc9 feat(web): ontology extract add form rules 2026-03-04 17:04:25 +08:00
yingzhao
8ada221777 Merge pull request #455 from SuanmoSuanyangTechnology/feature/form_zy
Feature/form zy
2026-03-04 16:47:14 +08:00
zhaoying
4ee198813a feat(web): custom tool add form rules 2026-03-04 16:46:25 +08:00
zhaoying
440e8acd99 feat(web): mcp tool add form rules 2026-03-04 16:42:15 +08:00
Mark
37325e9802 Merge pull request #452 from SuanmoSuanyangTechnology/fix/workflow-api-stream
fix(workflow): fix incorrect fields in streaming API output
2026-03-04 16:06:03 +08:00
Eternity
778bc4bd70 fix(workflow): fix incorrect fields in streaming API output 2026-03-04 15:58:49 +08:00
lixiangcheng1
f78f59ec42 Merge branch 'feature/knowledge_lxc' into develop 2026-03-04 15:42:06 +08:00
lixiangcheng1
d4c4160215 【ADD]Knowledge base retrieval supports file set retrieval 2026-03-04 15:28:17 +08:00
yujiangping
85aea97c21 chore(web): disable market tab in tool management
- Comment out Market component rendering in ToolManagement view
- Update LastEditTime timestamp in file header
- Market tab functionality temporarily disabled pending further developmen
2026-03-04 15:13:14 +08:00
yujiangping
b075cad4de Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop 2026-03-04 15:03:04 +08:00
yujiangping
f326febc8a feat:tool market add 2026-03-04 14:40:27 +08:00
Ke Sun
1738e45090 Merge pull request #451 from SuanmoSuanyangTechnology/fix/memory_incorememt
[changes] Setting the environment variable for the scheduled task time
2026-03-04 14:22:38 +08:00
lanceyq
6e758faa37 [changes] Using Pydantic to standardize the time data for scheduled tasks 2026-03-04 14:17:45 +08:00
Ke Sun
32e79c5df0 Fix/interest distribution (#445)
* [fix] Revising the judgment method for the interest analysis tags

* [fix] Revising the judgment method for the interest analysis tags

* [add] Set cache for the distribution of interest tags

* [fix] Revising the judgment method for the interest analysis tags

* [add] Set cache for the distribution of interest tags

* [changes] 1.Use structured logs;
          2.Align the type and default value of "end_user_id" with the semantic meaning of "required".
2026-03-04 14:06:50 +08:00
Ke Sun
da4a1f536d Merge pull request #450 from SuanmoSuanyangTechnology/fix/workflow-output
fix(workflow): rename output message field
2026-03-04 13:53:08 +08:00
lanceyq
b3af757167 [changes] Setting the environment variable for the scheduled task time 2026-03-04 13:51:31 +08:00
Eternity
82794f051a fix(workflow): rename output message field 2026-03-04 13:49:33 +08:00
Mark
c041d24989 Merge pull request #446 from SuanmoSuanyangTechnology/feature/agent-variable
fix(workflow): rename output message field
2026-03-04 12:32:04 +08:00
yingzhao
1d662fb63e Merge pull request #448 from SuanmoSuanyangTechnology/feature/memory_zy
feat(web): short term detail use Markdown
2026-03-04 12:27:49 +08:00
yingzhao
d1933d2aef Merge pull request #447 from SuanmoSuanyangTechnology/feature/workflow_import_zy
feat(web): workflow chat use content replace chunk
2026-03-04 12:25:06 +08:00
Eternity
163872be6e fix(workflow): rename output message field 2026-03-04 12:23:17 +08:00
zhaoying
14fcb66a9c feat(web): short term detail use Markdown 2026-03-04 12:19:48 +08:00
lanceyq
c488eb0cd0 [changes] 1.Use structured logs;
2.Align the type and default value of "end_user_id" with the semantic meaning of "required".
2026-03-04 12:17:34 +08:00
zhaoying
91d20f7272 feat(web): workflow chat use content replace chunk 2026-03-04 12:12:21 +08:00
lanceyq
c3d7963fe0 Merge branch 'fix/interest_distribution' of github.com:SuanmoSuanyangTechnology/MemoryBear into fix/interest_distribution 2026-03-04 12:10:08 +08:00
lanceyq
c31a92bf01 [add] Set cache for the distribution of interest tags 2026-03-04 12:10:00 +08:00
lanceyq
b5703c1b82 [fix] Revising the judgment method for the interest analysis tags 2026-03-04 12:09:59 +08:00
lanceyq
df34735a9b [add] Set cache for the distribution of interest tags 2026-03-04 12:08:57 +08:00
zhaoying
31bee889d7 feat(web): model add is_vision/is_omni config 2026-03-04 11:52:54 +08:00
Ke Sun
b3ba0a6ed6 Merge pull request #443 from SuanmoSuanyangTechnology/fix/memory_incorememt
[changes] The timing of the memory increment task has been changed fr…
2026-03-04 11:16:58 +08:00
lanceyq
ce3b7897d7 Merge branch 'fix/interest_distribution' of github.com:SuanmoSuanyangTechnology/MemoryBear into fix/interest_distribution 2026-03-04 11:06:20 +08:00
lanceyq
9115ad6950 [fix] Revising the judgment method for the interest analysis tags 2026-03-04 11:06:08 +08:00
yingzhao
c6b76438f4 Merge pull request #444 from SuanmoSuanyangTechnology/feature/memory_zy
feat(web): change interest distribution api
2026-03-04 11:00:56 +08:00
zhaoying
68c4c7429c feat(web): change interest distribution api 2026-03-04 10:59:29 +08:00
lanceyq
8466c8e019 [fix] Revising the judgment method for the interest analysis tags 2026-03-03 23:30:54 +08:00
lanceyq
d899b27448 [changes] The timing of the memory increment task has been changed from relative time to absolute time. 2026-03-03 22:46:05 +08:00
Ke Sun
66c153f1ad refactor(api): improve memory service dependency injection and code organization
- Update ShortService and LongService constructors to accept db Session parameter for proper dependency injection instead of using module-level db instance
- Reorganize imports in memory_short_term_controller.py following PEP 8 conventions (stdlib, third-party, local imports)
- Add comprehensive docstrings with type hints to ShortService and LongService methods for better code documentation
- Fix import organization in memory_short_service.py to group related imports and improve readability
- Reorganize imports in user_memory_service.py to follow consistent import ordering patterns
- Update ShortService instantiation in analytics_memory_types to pass db parameter
- Remove module-level db instance initialization in favor of caller-managed database session lifecycle
- Add type annotations to method signatures (end_user_id: str, db: Session, return types)
- Improve code formatting and spacing consistency across memory service files
2026-03-03 16:48:34 +08:00
yingzhao
c6c7a1827c Merge pull request #440 from SuanmoSuanyangTechnology/feature/workflow_import_zy
Feature/workflow import zy
2026-03-03 15:33:13 +08:00
yujiangping
8fdaebbe6e Merge branch 'fix/release_web_yjp' into develop 2026-03-03 15:02:20 +08:00
zhaoying
9a98ccff2c feat(web): agent compare chat add variables 2026-03-03 14:48:50 +08:00
yujiangping
ee4027c561 feat(web): enhance knowledge base sharing with stop share feedback
- Fix file download URL to use absolute API path instead of apiPrefix variable
- Add stopShareSuccess i18n message for English locale
- Add stopShareSuccess i18n message for Chinese locale
- Update ShareModal to display different success messages based on share toggle state
- Show "Sharing is off" message when disabling knowledge base sharing
- Improve user feedback when toggling share status on/off
2026-03-03 14:47:24 +08:00
zhaoying
7f36a06f26 fix(web): update share version modal's title 2026-03-03 14:05:02 +08:00
zhaoying
0826a34d8b fix(web): http node body variable filter update 2026-03-03 13:57:31 +08:00
zhaoying
1792cb4d93 feat(web): chat add variables 2026-03-03 13:48:50 +08:00
Ke Sun
304ccef101 chore(api): organize imports and refactor database context management 2026-03-03 12:30:09 +08:00
Mark
bdc22c892d Merge pull request #437 from SuanmoSuanyangTechnology/fix/agent-files
fix(agent): fix issue where default runtime file list configuration was empty
2026-03-03 12:27:37 +08:00
Eternity
a5034e84ba fix(agent): fix issue where default runtime file list configuration was empty 2026-03-03 12:19:43 +08:00
Ke Sun
6e2de96fed Merge pull request #436 from SuanmoSuanyangTechnology/refactor/modify-path
[changes] modify-path
2026-03-03 12:18:15 +08:00
lanceyq
2b6d86e591 [changes] 2026-03-03 11:49:33 +08:00
Mark
8c6f4cb117 Merge pull request #434 from SuanmoSuanyangTechnology/feature/app-share-config
feat(app): add API to retrieve app configuration fields
2026-03-03 11:25:35 +08:00
yingzhao
16d4b32eb7 Merge pull request #435 from SuanmoSuanyangTechnology/feature/workflow_import_zy
fix(web): agent's variables init update
2026-03-03 11:24:10 +08:00
zhaoying
45a64dbbac fix(web): agent's variables init update 2026-03-03 11:15:14 +08:00
Eternity
537668b463 Merge pull request #432 from SuanmoSuanyangTechnology/feature/workflow_import_zy
Feature/workflow import zy
2026-03-03 11:08:24 +08:00
Eternity
07fea23dd0 feat(app): add API to retrieve app configuration fields 2026-03-03 10:48:22 +08:00
yingzhao
cef14291f0 Merge pull request #432 from SuanmoSuanyangTechnology/feature/workflow_import_zy
Feature/workflow import zy
2026-03-03 10:29:32 +08:00
yingzhao
bbde0588af Merge pull request #433 from SuanmoSuanyangTechnology/feature/form_zy
fix(web): change string regExp
2026-03-03 10:29:10 +08:00
zhaoying
aa7d52568b fix(web): change string regExp 2026-03-03 10:24:21 +08:00
yingzhao
f39c77ac70 Merge branch 'develop' into feature/workflow_import_zy 2026-03-03 10:16:59 +08:00
zhaoying
aa733354e8 fix(web): Editor input type add blur event 2026-03-03 10:14:36 +08:00
yingzhao
7cec966979 Merge pull request #431 from SuanmoSuanyangTechnology/feature/workflow_import_zy
feat(web): update file type
2026-03-02 18:45:43 +08:00
yingzhao
74865d2cf2 Merge pull request #430 from SuanmoSuanyangTechnology/feature/form_zy
revert(web): revert file
2026-03-02 18:44:51 +08:00
zhaoying
c9a8753473 revert(web): revert file 2026-03-02 18:38:08 +08:00
zhaoying
ce8a2cbe34 feat(web): update file type 2026-03-02 18:32:19 +08:00
yingzhao
c0fdd0c6d3 Merge pull request #429 from SuanmoSuanyangTechnology/feature/form_zy
Feature/form zy
2026-03-02 18:29:54 +08:00
yingzhao
88bfcfe6cd Merge pull request #428 from SuanmoSuanyangTechnology/feature/workflow_import_zy
Feature/workflow import zy
2026-03-02 18:29:25 +08:00
zhaoying
c4dcf1fd65 Merge branch 'feature/form_zy' of https://github.com/SuanmoSuanyangTechnology/MemoryBear into feature/form_zy 2026-03-02 18:26:23 +08:00
zhaoying
6cebddf893 feat(web): workflow runtime add error info 2026-03-02 18:14:36 +08:00
Mark
1738ed3664 Merge pull request #427 from SuanmoSuanyangTechnology/fix/workflow-variable
fix(workflow): handle non-stream output field changes, add file type support to HTTP node, fix iteration node flattening bug
2026-03-02 17:55:54 +08:00
zhaoying
37ddcb91ac feat(web): update text 2026-03-02 17:51:30 +08:00
Eternity
574ab4506b feat(workflow): add placeholder node for unknown types 2026-03-02 17:37:59 +08:00
zhaoying
81353538e5 feat(web): http node config support editor 2026-03-02 17:26:24 +08:00
zhaoying
5abfcdfbe8 feat(web): add unknown node 2026-03-02 17:07:29 +08:00
zhaoying
9962a61c21 feat(web): update app api 2026-03-02 15:54:35 +08:00
Eternity
5cf2b08777 fix(workflow): handle non-stream output field changes, add file type support to HTTP node, fix iteration node flattening bug 2026-03-02 14:59:12 +08:00
zhaoying
9be1c01b70 feat(web): chat content support scroll 2026-03-02 14:43:44 +08:00
zhaoying
62b2ecdfc2 feat(web): form add rules 2026-03-02 14:41:58 +08:00
zhaoying
2ff9000d25 feat(web): form add rules 2026-03-02 14:39:47 +08:00
Ke Sun
5829148ce4 Merge pull request #425 from SuanmoSuanyangTechnology/fix/2.6-bug
Fix/2.6 bug
2026-03-02 14:27:33 +08:00
lanceyq
8e15a340f6 [changes]Correct log output, log level, and pruning conditions 2026-03-02 12:09:10 +08:00
yingzhao
1270b7cdd8 Merge pull request #426 from SuanmoSuanyangTechnology/feature/memory_zy
feat(web): memoryExtractionEngine add pruning
2026-03-02 11:54:24 +08:00
lanceyq
7c02fe8148 Merge branch 'fix/2.6-bug' of github.com:SuanmoSuanyangTechnology/MemoryBear into fix/2.6-bug 2026-03-02 11:49:37 +08:00
lanceyq
4ac63e1c23 [add]Complete the interface integration for the display of semantic pruning for streaming output. 2026-03-02 11:49:28 +08:00
lanceyq
4aeb653ed2 [fix]Fix the display issue of semantic chunking for streaming output 2026-03-02 11:49:28 +08:00
lanceyq
2d5c2de613 [add]New semantic pruning effect display for streaming output 2026-03-02 11:49:28 +08:00
lanceyq
96590941cf [add]The semantic pruning function is activated, removing the protection of question-answer pairs. 2026-03-02 11:49:28 +08:00
lanceyq
0655ff4a91 [fix]Correct the flaws existing in the semantic segmentation method 2026-03-02 11:49:28 +08:00
lanceyq
0ba370052e [fix]Address the shortcomings of intelligent pruning 2026-03-02 11:49:28 +08:00
lanceyq
4d59e04aba [changes]Ensure that there are sufficient labels for LLM to process, and control the number of label returns. 2026-03-02 11:49:28 +08:00
lanceyq
6db6c33564 [fix]Reduce the default number of items returned for popular tags 2026-03-02 11:49:28 +08:00
lanceyq
ed0d963aec [fix]Modify the person who generates the user summary 2026-03-02 11:49:28 +08:00
lanceyq
3a36d038ee [fix]Reconstructing memory incremental statistical scheduling task 2026-03-02 11:49:28 +08:00
lanceyq
3d068a9c96 [fix]Complete the API call logic for the homepage 2026-03-02 11:49:28 +08:00
zhaoying
87df352adc feat(web): memoryExtractionEngine add pruning 2026-03-02 11:42:46 +08:00
lanceyq
8b546b7366 [add]Complete the interface integration for the display of semantic pruning for streaming output. 2026-02-28 19:26:16 +08:00
Mark
77ea0680fb [add] migration script 2026-02-28 19:22:13 +08:00
乐力齐
4c592bf7e3 Feature/default ontology (#424)
* [add]Create a workspace and initialize the default ontology engineering scenario

* [add]The language parameters for creating the workspace determine the default language for switching in the ontology project.

* [changes]Standardized return format

* [add]The default ontology is associated with the default configuration.

* [add]Create a workspace and initialize the default ontology engineering scenario

* [add]The language parameters for creating the workspace determine the default language for switching in the ontology project.

* [changes]Standardized return format

* [add]The default ontology is associated with the default configuration.
2026-02-28 18:58:33 +08:00
lixinyue11
6718553bf4 Fix/develop memory rag (#419)
* fix_rag/fast summary

* fix_rag/fast summary
2026-02-28 18:47:08 +08:00
Mark
79dc6f3f69 Merge pull request #417 from SuanmoSuanyangTechnology/fix/workflow-adapter
fix(workflow): enhance Dify import types, templates and tool nodes
2026-02-28 18:46:56 +08:00
Ke Sun
8df72d2822 Merge pull request #423 from SuanmoSuanyangTechnology/release/v0.2.5
Release/v0.2.5
2026-02-28 18:38:18 +08:00
Mark
3ce5926689 Merge pull request #416 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(app)
2026-02-28 18:23:14 +08:00
lanceyq
035464c0ac [fix]Fix the display issue of semantic chunking for streaming output 2026-02-28 18:19:44 +08:00
yingzhao
f1fcffbfc0 Merge pull request #420 from SuanmoSuanyangTechnology/feature/workflow_import_zy
feat(web): workflow import & export
2026-02-28 18:02:24 +08:00
zhaoying
b79fe07052 feat(web): workflow import & export 2026-02-28 18:01:00 +08:00
lanceyq
e6aa0e0e10 [add]New semantic pruning effect display for streaming output 2026-02-28 17:51:12 +08:00
Eternity
54700e6fbe fix(workflow): fix exceptions when importing configs from Dify 2026-02-28 17:32:35 +08:00
lanceyq
3a0671c661 [add]The semantic pruning function is activated, removing the protection of question-answer pairs. 2026-02-28 17:18:42 +08:00
Timebomb2018
1037729fb3 fix(model): The custom models in the model list can batch add APIkeys through the provider 2026-02-28 16:51:56 +08:00
Timebomb2018
5f211620c5 fix(app): Lock the conversation with the application dialogue 2026-02-28 14:01:49 +08:00
Timebomb2018
cb6a3aae9e Merge branch 'refs/heads/feature/20260105_xjn' into feature/agent-tool_xjn 2026-02-28 13:59:31 +08:00
Mark
5e512df3d4 Merge pull request #415 from SuanmoSuanyangTechnology/feature/workflow-adapter-dify
feat(workflow): add Dify workflow import adapter and related APIs
2026-02-28 13:18:30 +08:00
Eternity
9916cf3265 feat(workflow): add Dify workflow import adapter and related APIs 2026-02-28 11:26:52 +08:00
lanceyq
f7aed9dd98 [fix]Correct the flaws existing in the semantic segmentation method 2026-02-27 16:45:34 +08:00
lanceyq
5253cf3899 [fix]Address the shortcomings of intelligent pruning 2026-02-27 16:09:22 +08:00
lanceyq
f7d92be5ea [changes]Ensure that there are sufficient labels for LLM to process, and control the number of label returns. 2026-02-27 15:08:06 +08:00
lanceyq
97d8168824 [fix]Reduce the default number of items returned for popular tags 2026-02-27 14:59:28 +08:00
lanceyq
550bd4da23 [fix]Modify the person who generates the user summary 2026-02-27 14:47:23 +08:00
lixiangcheng1
4f0b653a82 【fix]The complexity and volume of the document content require an extended timeframe 2026-02-26 19:04:42 +08:00
Timebomb2018
616709acbb Merge branch 'refs/heads/feature/20260105_xjn' into feature/agent-tool_xjn 2026-02-26 16:18:21 +08:00
Timebomb2018
67053ab8ae fix(workspace member): After the space inviter is removed, it can still be invited again. 2026-02-26 13:35:07 +08:00
lixiangcheng1
33238d34c9 [fix]Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) 2026-02-26 10:17:44 +08:00
280 changed files with 14920 additions and 4939 deletions

1
.gitignore vendored
View File

@@ -29,6 +29,7 @@ search_results.json
api/migrations/versions
tmp
files
powers/
# Exclude dep files
huggingface.co/

View File

@@ -226,8 +226,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (Using Redis as broker)
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
# JWT Secret Key (Formation method: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here

View File

@@ -201,8 +201,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (使用Redis作为broker)
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
# JWT密钥 (生成方式: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here

View File

@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
apt install -y ghostscript && \
apt install -y libmagic1
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \

View File

@@ -10,7 +10,6 @@ from app.core.config import settings
# 设置日志记录器
logger = logging.getLogger(__name__)
# 创建连接池
pool = ConnectionPool.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
@@ -21,6 +20,7 @@ pool = ConnectionPool.from_url(
)
aio_redis = redis.StrictRedis(connection_pool=pool)
async def get_redis_connection():
"""获取Redis连接"""
try:
@@ -29,7 +29,8 @@ async def get_redis_connection():
logger.error(f"Redis连接失败: {str(e)}")
return None
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
"""设置Redis键值
Args:
@@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
try:
if isinstance(val, dict):
val = json.dumps(val, ensure_ascii=False)
if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire)
@@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
except Exception as e:
logger.error(f"Redis set错误: {str(e)}")
async def aio_redis_get(key: str):
"""获取Redis键值"""
try:
@@ -58,6 +60,7 @@ async def aio_redis_get(key: str):
logger.error(f"Redis get错误: {str(e)}")
return None
async def aio_redis_delete(key: str):
"""删除Redis键"""
try:
@@ -66,6 +69,7 @@ async def aio_redis_delete(key: str):
logger.error(f"Redis delete错误: {str(e)}")
return None
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
"""发布消息到Redis频道"""
try:
@@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
logger.error(f"Redis发布错误: {str(e)}")
return False
class RedisSubscriber:
"""Redis订阅器"""
def __init__(self, channel: str):
self.channel = channel
self.conn = None
@@ -88,25 +93,25 @@ class RedisSubscriber:
self.is_closed = False
self._queue = asyncio.Queue()
self._task = None
async def start(self):
"""开始订阅"""
if self.is_closed or self._task:
return
self._task = asyncio.create_task(self._receive_messages())
logger.info(f"开始订阅: {self.channel}")
async def _receive_messages(self):
"""接收消息"""
try:
self.conn = await get_redis_connection()
if not self.conn:
return
self.pubsub = self.conn.pubsub()
await self.pubsub.subscribe(self.channel)
while not self.is_closed:
try:
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
@@ -127,7 +132,7 @@ class RedisSubscriber:
finally:
await self._queue.put(None)
await self._cleanup()
async def _cleanup(self):
"""清理资源"""
if self.pubsub:
@@ -141,7 +146,7 @@ class RedisSubscriber:
await self.conn.close()
except Exception:
pass
async def get_message(self) -> Optional[Dict[str, Any]]:
"""获取消息"""
if self.is_closed:
@@ -153,7 +158,7 @@ class RedisSubscriber:
except Exception as e:
logger.error(f"获取消息错误: {str(e)}")
return None
async def close(self):
"""关闭订阅器"""
if self.is_closed:
@@ -163,32 +168,33 @@ class RedisSubscriber:
self._task.cancel()
await self._cleanup()
class RedisPubSubManager:
"""Redis发布订阅管理器"""
def __init__(self):
self.subscribers = {}
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
return await aio_redis_publish(channel, message)
def get_subscriber(self, channel: str) -> RedisSubscriber:
if channel in self.subscribers:
subscriber = self.subscribers[channel]
if not subscriber.is_closed:
return subscriber
subscriber = RedisSubscriber(channel)
self.subscribers[channel] = subscriber
return subscriber
def cancel_subscription(self, channel: str) -> bool:
if channel in self.subscribers:
asyncio.create_task(self.subscribers[channel].close())
del self.subscribers[channel]
return True
return False
def cancel_all_subscriptions(self) -> int:
count = len(self.subscribers)
for subscriber in self.subscribers.values():
@@ -196,6 +202,6 @@ class RedisPubSubManager:
self.subscribers.clear()
return count
# 全局实例
pubsub_manager = RedisPubSubManager()

View File

@@ -2,7 +2,9 @@
Cache 缓存模块
提供各种缓存功能的统一入口
注意隐性记忆和情绪建议已迁移到数据库存储不再使用Redis缓存
"""
from .memory import InterestMemoryCache
__all__ = []
__all__ = [
"InterestMemoryCache",
]

View File

@@ -2,7 +2,11 @@
Memory 缓存模块
提供记忆系统相关的缓存功能
注意隐性记忆和情绪建议已迁移到数据库存储不再使用Redis缓存
"""
from .interest_memory import InterestMemoryCache
from .activity_stats_cache import ActivityStatsCache
__all__ = []
__all__ = [
"InterestMemoryCache",
"ActivityStatsCache",
]

View File

@@ -0,0 +1,124 @@
"""
Recent Activity Stats Cache
记忆提取活动统计缓存模块
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储24小时后释放
查询命令cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
ACTIVITY_STATS_CACHE_EXPIRE = 86400
class ActivityStatsCache:
"""记忆提取活动统计缓存类"""
PREFIX = "cache:memory:activity_stats"
@classmethod
def _get_key(cls, workspace_id: str) -> str:
"""生成 Redis key
Args:
workspace_id: 工作空间ID
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
@classmethod
async def set_activity_stats(
cls,
workspace_id: str,
stats: Dict[str, Any],
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
) -> bool:
"""设置记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
stats: 统计数据,格式:
{
"chunk_count": int,
"statements_count": int,
"triplet_entities_count": int,
"triplet_relations_count": int,
"temporal_count": int,
}
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(workspace_id)
payload = {
"stats": stats,
"generated_at": datetime.now().isoformat(),
"workspace_id": workspace_id,
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_activity_stats(
cls,
workspace_id: str,
) -> Optional[Dict[str, Any]]:
"""获取记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
统计数据字典,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(workspace_id)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中活动统计缓存: {key}")
return payload
logger.info(f"活动统计缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_activity_stats(
cls,
workspace_id: str,
) -> bool:
"""删除记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
是否删除成功
"""
try:
key = cls._get_key(workspace_id)
result = await aio_redis.delete(key)
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
return False

122
api/app/cache/memory/interest_memory.py vendored Normal file
View File

@@ -0,0 +1,122 @@
"""
Interest Distribution Cache
兴趣分布缓存模块
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
"""
import json
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
INTEREST_CACHE_EXPIRE = 86400
class InterestMemoryCache:
"""兴趣分布缓存类"""
PREFIX = "cache:memory:interest_distribution"
@classmethod
def _get_key(cls, end_user_id: str, language: str) -> str:
"""生成 Redis key
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
@classmethod
async def set_interest_distribution(
cls,
end_user_id: str,
language: str,
data: List[Dict[str, Any]],
expire: int = INTEREST_CACHE_EXPIRE,
) -> bool:
"""设置用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(end_user_id, language)
payload = {
"data": data,
"generated_at": datetime.now().isoformat(),
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> Optional[List[Dict[str, Any]]]:
"""获取用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
兴趣分布列表,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(end_user_id, language)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}")
return payload.get("data")
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> bool:
"""删除用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
是否删除成功
"""
try:
key = cls._get_key(end_user_id, language)
result = await aio_redis.delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
return False

View File

@@ -7,20 +7,48 @@ from celery import Celery
from celery.schedules import crontab
from app.core.config import settings
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
# Build canonical broker/backend URLs and force them into os.environ so that
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
# integration and accidentally override our canonical URLs.
os.environ.pop("BROKER_URL", None)
os.environ.pop("RESULT_BACKEND", None)
os.environ.pop("CELERY_BROKER", None)
os.environ.pop("CELERY_BACKEND", None)
celery_app = Celery(
"redbear_tasks",
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
broker=_broker_url,
backend=_backend_url,
)
logger.info(
"Celery app initialized",
extra={
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
},
)
# Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks'
@@ -35,17 +63,17 @@ celery_app.conf.update(
accept_content=['json'],
result_serializer='json',
# 时区
timezone='Asia/Shanghai',
enable_utc=True,
# # 时区
# timezone='Asia/Shanghai',
# enable_utc=False,
# 任务追踪
task_track_started=True,
task_ignore_result=False,
# 超时设置
task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=1500, # 25分钟软超时
task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
@@ -85,6 +113,8 @@ celery_app.conf.update(
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
},
)
@@ -92,7 +122,7 @@ celery_app.conf.update(
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)

View File

@@ -0,0 +1 @@
"""Configuration module for application settings."""

View File

@@ -0,0 +1,239 @@
"""默认本体场景配置
本模块定义系统预设的本体场景和实体类型配置。
这些配置用于在工作空间创建时自动初始化默认场景。
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
"""
# 在线教育场景配置
ONLINE_EDUCATION_SCENE = {
"name_chinese": "在线教育",
"name_english": "Online Education",
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
"types": [
{
"name_chinese": "学生",
"name_english": "Student",
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
},
{
"name_chinese": "教师",
"name_english": "Teacher",
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
},
{
"name_chinese": "课程",
"name_english": "Course",
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
},
{
"name_chinese": "作业",
"name_english": "Assignment",
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
},
{
"name_chinese": "成绩",
"name_english": "Grade",
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
},
{
"name_chinese": "考试",
"name_english": "Exam",
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
},
{
"name_chinese": "教室",
"name_english": "Classroom",
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
},
{
"name_chinese": "学科",
"name_english": "Subject",
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
},
{
"name_chinese": "教材",
"name_english": "Textbook",
"description_chinese": "教学使用的书籍或资料包含书名、作者、出版社、ISBN等属性",
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
},
{
"name_chinese": "班级",
"name_english": "Class",
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
},
{
"name_chinese": "学期",
"name_english": "Semester",
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
},
{
"name_chinese": "课时",
"name_english": "Class Hour",
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
},
{
"name_chinese": "教学计划",
"name_english": "Teaching Plan",
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
}
]
}
# 情感陪伴场景配置
EMOTIONAL_COMPANION_SCENE = {
"name_chinese": "情感陪伴",
"name_english": "Emotional Companion",
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
"types": [
{
"name_chinese": "用户",
"name_english": "User",
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
},
{
"name_chinese": "情绪",
"name_english": "Emotion",
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
},
{
"name_chinese": "活动",
"name_english": "Activity",
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
},
{
"name_chinese": "对话",
"name_english": "Conversation",
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
},
{
"name_chinese": "兴趣爱好",
"name_english": "Hobby",
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
},
{
"name_chinese": "日常事件",
"name_english": "Daily Event",
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
},
{
"name_chinese": "关系",
"name_english": "Relationship",
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
},
{
"name_chinese": "回忆",
"name_english": "Memory",
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
},
{
"name_chinese": "地点",
"name_english": "Location",
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
},
{
"name_chinese": "时间节点",
"name_english": "Time Point",
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
"description_english": "Important time markers, including attributes such as date, event, and significance"
},
{
"name_chinese": "目标",
"name_english": "Goal",
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
},
{
"name_chinese": "成就",
"name_english": "Achievement",
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
}
]
}
# 导出默认场景列表
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
"""获取场景名称(根据语言)
Args:
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的场景名称
"""
if language == "en":
return scene_config.get("name_english", scene_config.get("name_chinese"))
return scene_config.get("name_chinese")
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
"""获取场景描述(根据语言)
Args:
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的场景描述
"""
if language == "en":
return scene_config.get("description_english", scene_config.get("description_chinese"))
return scene_config.get("description_chinese")
def get_type_name(type_config: dict, language: str = "zh") -> str:
"""获取类型名称(根据语言)
Args:
type_config: 类型配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的类型名称
"""
if language == "en":
return type_config.get("name_english", type_config.get("name_chinese"))
return type_config.get("name_chinese")
def get_type_description(type_config: dict, language: str = "zh") -> str:
"""获取类型描述(根据语言)
Args:
type_config: 类型配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的类型描述
"""
if language == "en":
return type_config.get("description_english", type_config.get("description_chinese"))
return type_config.get("description_chinese")

View File

@@ -0,0 +1,249 @@
# -*- coding: utf-8 -*-
"""默认本体场景初始化器
本模块提供默认本体场景和类型的自动初始化功能。
在工作空间创建时,自动添加预设的本体场景和实体类型。
Classes:
DefaultOntologyInitializer: 默认本体场景初始化器
"""
import logging
from typing import List, Optional, Tuple
from uuid import UUID
from sqlalchemy.orm import Session
from app.config.default_ontology_config import (
DEFAULT_SCENES,
get_scene_name,
get_scene_description,
get_type_name,
get_type_description,
)
from app.core.logging_config import get_business_logger
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.repositories.ontology_class_repository import OntologyClassRepository
class DefaultOntologyInitializer:
"""默认本体场景初始化器
负责在工作空间创建时自动初始化默认的本体场景和类型。
遵循最小侵入原则,确保初始化失败不阻止工作空间创建。
Attributes:
db: 数据库会话
scene_repo: 场景Repository
class_repo: 类型Repository
logger: 业务日志记录器
"""
def __init__(self, db: Session):
"""初始化
Args:
db: 数据库会话
"""
self.db = db
self.scene_repo = OntologySceneRepository(db)
self.class_repo = OntologyClassRepository(db)
self.logger = get_business_logger()
def initialize_default_scenes(
self,
workspace_id: UUID,
language: str = "zh"
) -> Tuple[bool, str]:
"""为工作空间初始化默认场景
创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。
如果创建失败,记录错误日志但不抛出异常。
Args:
workspace_id: 工作空间ID
language: 语言类型 ("zh""en"),默认为 "zh"
Returns:
Tuple[bool, str]: (是否成功, 错误信息)
"""
try:
self.logger.info(
f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
)
scenes_created = 0
total_types_created = 0
# 遍历默认场景配置
for scene_config in DEFAULT_SCENES:
scene_name = get_scene_name(scene_config, language)
# 创建场景及其类型
scene_id = self._create_scene_with_types(workspace_id, scene_config, language)
if scene_id:
scenes_created += 1
# 统计类型数量
types_count = len(scene_config.get("types", []))
total_types_created += types_count
self.logger.info(
f"场景创建成功 - scene_name={scene_name}, "
f"scene_id={scene_id}, types_count={types_count}, language={language}"
)
else:
self.logger.warning(
f"场景创建失败 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, language={language}"
)
# 记录总体结果
self.logger.info(
f"默认场景初始化完成 - workspace_id={workspace_id}, "
f"language={language}, scenes_created={scenes_created}, "
f"total_types_created={total_types_created}"
)
# 如果至少创建了一个场景,视为成功
if scenes_created > 0:
return True, ""
else:
error_msg = "所有默认场景创建失败"
self.logger.error(
f"默认场景初始化失败 - workspace_id={workspace_id}, "
f"language={language}, error={error_msg}"
)
return False, error_msg
except Exception as e:
error_msg = f"默认场景初始化异常: {str(e)}"
self.logger.error(
f"默认场景初始化异常 - workspace_id={workspace_id}, "
f"language={language}, error={str(e)}",
exc_info=True
)
return False, error_msg
def _create_scene_with_types(
self,
workspace_id: UUID,
scene_config: dict,
language: str = "zh"
) -> Optional[UUID]:
"""创建场景及其类型
Args:
workspace_id: 工作空间ID
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
Optional[UUID]: 创建的场景ID失败返回None
"""
try:
scene_name = get_scene_name(scene_config, language)
scene_description = get_scene_description(scene_config, language)
# 检查是否已存在同名场景(支持向后兼容)
existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
if existing_scene:
self.logger.info(
f"场景已存在,跳过创建 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
f"language={language}"
)
return None
# 创建场景记录,设置 is_system_default=true
scene_data = {
"scene_name": scene_name,
"scene_description": scene_description
}
scene = self.scene_repo.create(scene_data, workspace_id)
# 设置系统默认标识
scene.is_system_default = True
self.db.flush()
self.logger.info(
f"场景创建成功 - scene_name={scene_name}, "
f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
)
# 批量创建类型
types_config = scene_config.get("types", [])
types_created = self._batch_create_types(scene.scene_id, types_config, language)
self.logger.info(
f"场景类型创建完成 - scene_id={scene.scene_id}, "
f"types_created={types_created}/{len(types_config)}, language={language}"
)
return scene.scene_id
except Exception as e:
scene_name = get_scene_name(scene_config, language)
self.logger.error(
f"场景创建失败 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, language={language}, error={str(e)}",
exc_info=True
)
return None
def _batch_create_types(
self,
scene_id: UUID,
types_config: List[dict],
language: str = "zh"
) -> int:
"""批量创建实体类型
Args:
scene_id: 场景ID
types_config: 类型配置列表
language: 语言类型 ("zh""en")
Returns:
int: 成功创建的类型数量
"""
created_count = 0
for type_config in types_config:
try:
type_name = get_type_name(type_config, language)
type_description = get_type_description(type_config, language)
# 创建类型数据
class_data = {
"class_name": type_name,
"class_description": type_description
}
# 创建类型
ontology_class = self.class_repo.create(class_data, scene_id)
# 设置系统默认标识
ontology_class.is_system_default = True
self.db.flush()
created_count += 1
self.logger.debug(
f"类型创建成功 - class_name={type_name}, "
f"class_id={ontology_class.class_id}, "
f"scene_id={scene_id}, is_system_default=True, language={language}"
)
except Exception as e:
type_name = get_type_name(type_config, language)
self.logger.warning(
f"单个类型创建失败,继续创建其他类型 - "
f"class_name={type_name}, scene_id={scene_id}, "
f"language={language}, error={str(e)}"
)
# 继续创建其他类型
continue
return created_count

View File

@@ -1,9 +1,12 @@
import uuid
import io
from typing import Optional, Annotated
from fastapi import APIRouter, Depends, Path
import yaml
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from urllib.parse import quote
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
@@ -17,12 +20,14 @@ from app.repositories.end_user_repository import EndUserRepository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
from app.services import app_service, workspace_service
from app.services.agent_config_helper import enrich_agent_config
from app.services.app_service import AppService
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_statistics_service import AppStatisticsService
from app.services.workflow_import_service import WorkflowImportService
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_dsl_service import AppDslService
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@@ -65,7 +70,7 @@ def list_apps(
# 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None:
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
@@ -394,10 +399,10 @@ async def draft_run(
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
service = AppService(db)
draft_service = DraftRunService(db)
draft_service = AgentRunService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
@@ -482,8 +487,8 @@ async def draft_run(
}
)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
@@ -787,8 +792,8 @@ async def draft_run_compare(
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
@@ -818,8 +823,8 @@ async def draft_run_compare(
)
# 非流式返回
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
@@ -833,7 +838,8 @@ async def draft_run_compare(
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
timeout=payload.timeout or 60,
files=payload.files
)
logger.info(
@@ -879,6 +885,60 @@ async def update_workflow_config(
return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.get("/{app_id}/workflow/export")
@cur_workspace_access_guard()
async def export_workflow_config(
app_id: uuid.UUID,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
"""导出工作流配置为YAML文件"""
workflow_service = WorkflowService(db)
return success(data={
"content": workflow_service.export_workflow_dsl(app_id=app_id),
})
@router.post("/workflow/import")
@cur_workspace_access_guard()
async def import_workflow_config(
file: UploadFile = File(...),
platform: str = Form(...),
app_id: str = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从YAML内容导入工作流配置"""
if not file.filename.lower().endswith((".yaml", ".yml")):
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
raw_text = (await file.read()).decode("utf-8")
import_service = WorkflowImportService(db)
config = yaml.safe_load(raw_text)
result = await import_service.upload_config(platform, config)
return success(data=result)
@router.post("/workflow/import/save")
@cur_workspace_access_guard()
async def save_workflow_import(
data: WorkflowImportSave,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
import_service = WorkflowImportService(db)
app = await import_service.save_workflow(
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
temp_id=data.temp_id,
name=data.name,
description=data.description,
)
return success(data=app_schema.App.model_validate(app))
@router.get("/{app_id}/statistics", summary="应用统计数据")
@cur_workspace_access_guard()
def get_app_statistics(
@@ -889,12 +949,14 @@ def get_app_statistics(
current_user=Depends(get_current_user),
):
"""获取应用统计数据
Args:
app_id: 应用ID
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
- daily_conversations: 每日会话数统计
- total_conversations: 总会话数
@@ -931,6 +993,8 @@ def get_workspace_api_statistics(
Args:
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
每日统计数据列表,每项包含:
@@ -949,3 +1013,57 @@ def get_workspace_api_statistics(
)
return success(data=result)
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
@cur_workspace_access_guard()
async def export_app(
app_id: uuid.UUID,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
release_id: Optional[uuid.UUID] = None
):
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
release_id: 指定发布版本id不传则导出当前草稿配置。
"""
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
encoded = quote(filename, safe=".")
yaml_bytes = yaml_str.encode("utf-8")
file_stream = io.BytesIO(yaml_bytes)
file_stream.seek(0)
return StreamingResponse(
file_stream,
media_type="application/octet-stream; charset=utf-8",
headers={"Content-Disposition": f"attachment; filename={encoded}",
"Content-Length": str(len(yaml_bytes))}
)
@router.post("/import", summary="从 YAML 文件导入应用")
@cur_workspace_access_guard()
async def import_app(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
"""
if not file.filename.lower().endswith((".yaml", ".yml")):
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
raw = (await file.read()).decode("utf-8")
dsl = yaml.safe_load(raw)
if not dsl or "app" not in dsl:
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
new_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
)
return success(
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
)

View File

@@ -441,14 +441,14 @@ async def retrieve_chunks(
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
# Efficient deduplication
seen_ids = set()
unique_rs = []

View File

@@ -55,6 +55,12 @@ async def get_mcp_servers(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
if page * pagesize > 100:
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
)
# 2. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
@@ -64,14 +70,16 @@ async def get_mcp_servers(
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
return success(msg='The mcp market config does not exist or access is denied')
# 3. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
body = {
@@ -90,7 +98,7 @@ async def get_mcp_servers(
cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
api_logger.error(f"Failed to get MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get MCP servers: {str(e)}"
@@ -118,6 +126,67 @@ async def get_mcp_servers(
return success(data=result, msg="Query of mcp servers list successful")
@router.get("/operational_mcp_servers", response_model=ApiResponse)
async def get_operational_mcp_servers(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the operational mcp servers list in pages
- Support keyword search for name,author,owner
- Return paging metadata + operational mcp server list
"""
api_logger.info(
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
return success(msg='The mcp market config does not exist or access is denied')
# 2. Execute paged query
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers)
try:
cookies = api.get_cookies(access_token=token, cookies_required=True)
r = api.session.get(url, headers=headers, cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get operational MCP servers: {str(e)}"
)
data = api._handle_response(r)
total = data.get('total_count', 0)
mcp_server_list = data.get('mcp_server_list', [])
# items = [{
# 'name': item.get('name', ''),
# 'id': item.get('id', ''),
# 'description': item.get('description', '')
# } for item in mcp_server_list]
# 3. Return structured response
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
@router.get("/mcp_server", response_model=ApiResponse)
async def get_mcp_server(
mcp_market_config_id: uuid.UUID,
@@ -139,14 +208,16 @@ async def get_mcp_server(
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
return success(msg='The mcp market config does not exist or access is denied')
# 2. Get detailed information for a specific MCP Server
api = MCPApi()
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
result = api.get_mcp_server(server_id=server_id)
@@ -167,7 +238,26 @@ async def create_mcp_market_config(
try:
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
# 1. Check if the mcp market name already exists
# 1. Validate token can access ModelScope MCP market
if not create_data.token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Token is required to access ModelScope MCP market"
)
try:
api = MCPApi()
api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
)
# 2. Check if the mcp market name already exists
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
if db_mcp_market_config_exist:
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
@@ -203,10 +293,7 @@ async def get_mcp_market_config(
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
return success(msg='The mcp market config does not exist or access is denied')
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
@@ -236,10 +323,7 @@ async def get_mcp_market_config_by_mcp_market_id(
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
return success(msg='The mcp market config does not exist or access is denied')
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
@@ -265,12 +349,25 @@ async def update_mcp_market_config(
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or you do not have permission to access it"
)
return success(msg='The mcp market config does not exist or access is denied')
# 2. Update fields (only update non-null fields)
# 2. Validate new token if provided
if update_data.token is not None:
try:
api = MCPApi()
api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
)
# 3. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
update_dict = update_data.dict(exclude_unset=True)
updated_fields = []
@@ -285,7 +382,7 @@ async def update_mcp_market_config(
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database
# 4. Save to database
try:
db.commit()
db.refresh(db_mcp_market_config)
@@ -322,10 +419,7 @@ async def delete_mcp_market_config(
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or you do not have permission to access it"
)
return success(msg='The mcp market config does not exist or access is denied')
# 2. Deleting mcp market config
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)

View File

@@ -1,27 +1,29 @@
from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.repositories import knowledge_repository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv()
api_logger = get_api_logger()
@@ -36,7 +38,7 @@ router = APIRouter(
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user)
):
"""
Get latest health status written by Celery periodic task
@@ -54,8 +56,9 @@ async def get_health_status(
@router.get("/download_log")
async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
log_type: str = Query("file", regex="^(file|transmission)$",
description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
):
"""
Download or stream agent service log file
@@ -74,16 +77,16 @@ async def download_log(
- transmission mode: StreamingResponse with SSE
"""
api_logger.info(f"Log download requested with log_type={log_type}")
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
if log_type not in ["file", "transmission"]:
api_logger.warning(f"Invalid log_type parameter: {log_type}")
return fail(
BizCode.BAD_REQUEST,
"无效的log_type参数",
BizCode.BAD_REQUEST,
"无效的log_type参数",
"log_type必须是'file''transmission'"
)
# Route to appropriate mode
if log_type == "file":
# File mode: Return complete log file content
@@ -118,10 +121,10 @@ async def download_log(
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Write service endpoint - processes write operations synchronously
@@ -135,11 +138,11 @@ async def write_server(
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -148,7 +151,7 @@ async def write_server(
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
@@ -160,13 +163,15 @@ async def write_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
@@ -174,7 +179,7 @@ async def write_server(
messages_list,
config_id,
db,
storage_type,
storage_type,
user_rag_memory_id,
language
)
@@ -194,10 +199,10 @@ async def write_server(
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Async write service endpoint - enqueues write processing to Celery
@@ -212,10 +217,11 @@ async def write_server_async(
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
@@ -243,7 +249,7 @@ async def write_server_async(
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
)
api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}")
@@ -253,9 +259,9 @@ async def write_server_async(
@router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Read service endpoint - processes read operations synchronously
@@ -290,8 +296,9 @@ async def read_server(
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.end_user_id,
@@ -305,7 +312,8 @@ async def read_server(
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
@@ -318,7 +326,7 @@ async def read_server(
db=db
)
if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info
result['answer'] = retrieve_info
return success(data=result, msg="回复对话消息成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -334,9 +342,10 @@ async def read_server(
@router.post("/file", response_model=ApiResponse)
async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"),
model_id:str = Form(..., description="模型ID"),
model_id: str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
文件上传接口 - 支持图片识别
@@ -349,9 +358,6 @@ async def file_update(
Returns:
文件处理结果
"""
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0]
@@ -360,7 +366,7 @@ async def file_update(
for file in files:
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
content = await file.read()
if file.content_type and file.content_type.startswith("image/"):
vision_model = QWenCV(
key=apiConfig.api_key,
@@ -374,12 +380,12 @@ async def file_update(
else:
api_logger.warning(f"Unsupported file type: {file.content_type}")
file_content.append(f"[不支持的文件类型: {file.content_type}]")
result_text = ';'.join(file_content)
api_logger.info(f"File processing completed, result length: {len(result_text)}")
return success(data=result_text, msg="转换文本成功")
except Exception as e:
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@@ -429,8 +435,8 @@ async def read_server_async(
@router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async read task
@@ -451,7 +457,7 @@ async def get_read_task_result(
try:
result = task_service.get_task_memory_read_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -469,7 +475,7 @@ async def get_read_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="查询任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -478,7 +484,7 @@ async def get_read_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -498,7 +504,7 @@ async def get_read_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -506,8 +512,8 @@ async def get_read_task_result(
@router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async write task
@@ -528,7 +534,7 @@ async def get_write_task_result(
try:
result = task_service.get_task_memory_write_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -546,7 +552,7 @@ async def get_write_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="写入任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -555,7 +561,7 @@ async def get_write_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -575,7 +581,7 @@ async def get_write_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -583,9 +589,9 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Determine the type of user message (read or write)
@@ -628,9 +634,10 @@ async def status_type(
@router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user)
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
@@ -639,14 +646,9 @@ async def get_knowledge_type_stats_api(
- 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
@@ -654,48 +656,70 @@ async def get_knowledge_type_stats_api(
current_workspace_id=current_user.current_workspace_id,
db=db
)
return success(data=result, msg="获取知识库类型统计成功")
except Exception as e:
api_logger.error(f"Knowledge type stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
limit: int = Query(20, description="返回标签数量限制"),
current_user: User = Depends(get_current_user),
db: Session=Depends(get_db),
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
获取指定用户的热门记忆标签
获取指定用户的兴趣分布标签
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
返回格式:
[
{"name": "标签", "frequency": 频次},
{"name": "兴趣活动", "frequency": 频次},
...
]
"""
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
language = get_language_from_header(language_type)
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
try:
result = await memory_agent_service.get_hot_memory_tags_by_user(
# 优先读取缓存
cached = await InterestMemoryCache.get_interest_distribution(
end_user_id=end_user_id,
limit=limit
language=language,
)
return success(data=result, msg="获取热门记忆标签成功")
if cached is not None:
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
return success(data=cached, msg="获取兴趣分布标签成功")
# 缓存未命中,调用模型生成
result = await memory_agent_service.get_interest_distribution_by_user(
end_user_id=end_user_id,
limit=limit,
language=language
)
# 写入缓存24小时过期
await InterestMemoryCache.set_interest_distribution(
end_user_id=end_user_id,
language=language,
data=result,
)
return success(data=result, msg="获取兴趣分布标签成功")
except Exception as e:
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
api_logger.error(f"Interest distribution by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取用户详情,包含:
@@ -733,17 +757,17 @@ async def get_user_profile_api(
# ):
# """
# Get parsed API documentation (Public endpoint - no authentication required)
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Parsed API documentation including title, meta info, and sections
# """
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
# try:
# result = await memory_agent_service.get_api_docs(file_path)
# if result.get("success"):
# return success(msg=result["msg"], data=result["data"])
# else:
@@ -759,9 +783,9 @@ async def get_user_profile_api(
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取终端用户关联的记忆配置
@@ -780,9 +804,9 @@ async def get_end_user_connected_config(
from app.services.memory_agent_service import (
get_end_user_connected_config as get_config,
)
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
try:
result = get_config(end_user_id, db)
return success(data=result, msg="获取终端用户关联配置成功")
@@ -791,4 +815,4 @@ async def get_end_user_connected_config(
return fail(BizCode.NOT_FOUND, str(e))
except Exception as e:
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))

View File

@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from typing import Optional
from app.core.response_utils import success
@@ -149,6 +150,21 @@ async def get_workspace_end_users(
return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
try:
from app.celery_app import celery_app as _celery_app
_celery_app.send_task(
"app.tasks.init_implicit_emotions_for_users",
kwargs={"end_user_ids": end_user_ids},
)
_celery_app.send_task(
"app.tasks.init_interest_distribution_for_users",
kwargs={"end_user_ids": end_user_ids},
)
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
@@ -387,14 +403,15 @@ def get_current_user_rag_total_num(
@router.get("/rag_content", response_model=ApiResponse)
def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
page: int = Query(1, gt=0, description="页码从1开始"),
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取当前宿主知识库中的chunk内容
获取当前宿主知识库中的chunk内容(分页)
"""
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
return success(data=data, msg="宿主RAGchunk数据获取成功")
@@ -407,26 +424,18 @@ async def get_chunk_summary_tag(
current_user: User = Depends(get_current_user),
):
"""
获取chunk总结、提取的标签和人物形象
读取RAG摘要、标签和人物形象纯读库不触发生成
返回格式:
{
"summary": "chunk内容的总结",
"tags": [
{"tag": "标签1", "frequency": 5},
{"tag": "标签2", "frequency": 3},
...
],
"personas": [
"产品设计师",
"旅行爱好者",
"摄影发烧友",
...
]
"summary": "用户摘要",
"tags": [{"tag": "标签1", "frequency": 5}, ...],
"personas": ["产品设计师", ...],
"generated": true/false // false表示尚未生产请调用 /generate_rag_profile
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk摘要标签人物形象")
api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG摘要/标签/人物形象")
data = await memory_dashboard_service.get_chunk_summary_and_tags(
end_user_id=end_user_id,
limit=limit,
@@ -434,9 +443,8 @@ async def get_chunk_summary_tag(
db=db,
current_user=current_user
)
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
return success(data=data, msg="获取成功")
@router.get("/chunk_insight", response_model=ApiResponse)
@@ -447,24 +455,57 @@ async def get_chunk_insight(
current_user: User = Depends(get_current_user),
):
"""
获取chunk的洞察内容
读取RAG洞察报告纯读库不触发生成
返回格式:
{
"insight": "对chunk内容的深度洞察分析"
"insight": "总体概述",
"behavior_pattern": "行为模式",
"key_findings": "关键发现",
"growth_trajectory": "成长轨迹",
"generated": true/false // false表示尚未生产请调用 /generate_rag_profile
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk洞察")
api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG洞察")
data = await memory_dashboard_service.get_chunk_insight(
end_user_id=end_user_id,
limit=limit,
db=db,
current_user=current_user
)
api_logger.info("成功获取chunk洞察")
return success(data=data, msg="chunk洞察获取成功")
return success(data=data, msg="获取成功")
class GenerateRagProfileRequest(BaseModel):
end_user_id: str = Field(..., description="宿主ID")
limit: int = Field(15, description="参与生成的chunk数量上限")
max_tags: int = Field(10, description="最大标签数量")
@router.post("/generate_rag_profile", response_model=ApiResponse)
async def generate_rag_profile(
body: GenerateRagProfileRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
生产接口为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
每次请求都会重新生成,覆盖已有数据。
"""
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
data = await memory_dashboard_service.generate_rag_profile(
end_user_id=body.end_user_id,
limit=body.limit,
max_tags=body.max_tags,
db=db,
current_user=current_user,
)
api_logger.info(f"RAG画像生产完成: {data}")
return success(data=data, msg="RAG画像生产完成")
@router.get("/dashboard_data", response_model=ApiResponse)
@@ -606,8 +647,8 @@ async def dashboard_data(
# 获取RAG相关数据
try:
# total_memory: 使用 total_chunkchunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
# total_memory: 只统计用户知识库permission_id='Memory')的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量

View File

@@ -1,16 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, status,Header
from typing import Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy.orm import Session
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_short_service import LongService, ShortService
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
@@ -29,11 +31,11 @@ async def short_term_configs(
language = get_language_from_header(language_type)
# 获取短期记忆数据
short_term=ShortService(end_user_id)
short_term=ShortService(end_user_id, db)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id)
long_term=LongService(end_user_id, db)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)

View File

@@ -2,7 +2,7 @@ from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
@@ -85,6 +85,7 @@ def create_config(
payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -99,7 +100,29 @@ def create_config(
svc = DataConfigService(db)
result = svc.create(payload)
return success(data=result, msg="创建成功")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
config_name = err_str.split(":", 1)[1]
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
except Exception as e:
from sqlalchemy.exc import IntegrityError
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -521,10 +544,11 @@ async def clear_hot_memory_tags_cache(
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info("Recent activity stats requested")
) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
try:
result = await analytics_recent_activity_stats()
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}")

View File

@@ -371,6 +371,11 @@ def update_model(
if model_data.type is not None or model_data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
if model_data.is_active:
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
if not active_keys:
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
@@ -469,7 +474,9 @@ async def create_model_api_key_by_provider(
config=api_key_data.config,
is_active=api_key_data.is_active,
priority=api_key_data.priority,
model_config_ids=model_config_ids
model_config_ids=model_config_ids,
capability=api_key_data.capability,
is_omni=api_key_data.is_omni
)
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)

View File

@@ -25,13 +25,13 @@ from typing import Dict, Optional, List
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.logging_config import get_api_logger, get_business_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
@@ -61,6 +61,7 @@ from app.repositories.ontology_scene_repository import OntologySceneRepository
api_logger = get_api_logger()
business_logger = get_business_logger()
logger = logging.getLogger(__name__)
router = APIRouter(
@@ -123,15 +124,23 @@ def _get_ontology_service(
)
# 通过 Repository 获取可用的 API Key负载均衡逻辑由 Repository 处理)
from app.repositories.model_repository import ModelApiKeyRepository
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
if not api_keys:
# from app.repositories.model_repository import ModelApiKeyRepository
from app.services.model_service import ModelApiKeyService
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
if not api_key_config:
logger.error(f"Model {llm_id} has no active API key")
raise HTTPException(
status_code=400,
detail="指定的LLM模型没有可用的API密钥"
)
api_key_config = api_keys[0]
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
# if not api_keys:
# logger.error(f"Model {llm_id} has no active API key")
# raise HTTPException(
# status_code=400,
# detail="指定的LLM模型没有可用的API密钥"
# )
# api_key_config = api_keys[0]
is_composite = getattr(model_config, 'is_composite', False)
logger.info(
@@ -153,6 +162,7 @@ def _get_ontology_service(
provider=actual_provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
max_retries=3,
timeout=60.0
)
@@ -279,7 +289,8 @@ async def extract_ontology(
async def create_scene(
request: SceneCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
):
"""创建本体场景
@@ -350,8 +361,18 @@ async def create_scene(
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
err_str = str(e)
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
from app.core.language_utils import get_language_from_header
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
except Exception as e:
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
@@ -399,6 +420,20 @@ async def update_scene(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认场景
scene_repo = OntologySceneRepository(db)
scene = scene_repo.get_by_id(scene_uuid)
if scene and scene.is_system_default:
business_logger.warning(
f"尝试修改系统默认场景: user_id={current_user.id}, "
f"scene_id={scene_id}, scene_name={scene.scene_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认场景不可修改",
"该场景为系统预设场景,不允许修改"
)
# 创建OntologyService实例
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
@@ -491,6 +526,19 @@ async def delete_scene(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认场景
scene_repo = OntologySceneRepository(db)
scene = scene_repo.get_by_id(scene_uuid)
if scene and scene.is_system_default:
business_logger.warning(
f"尝试删除系统默认场景: user_id={current_user.id}, "
f"scene_id={scene_id}, scene_name={scene.scene_name}"
)
raise HTTPException(
status_code=400,
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
)
# 创建OntologyService实例
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
@@ -514,6 +562,9 @@ async def delete_scene(
return success(data={"deleted": success_flag}, msg="场景删除成功")
except HTTPException:
raise
except ValueError as e:
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
@@ -621,7 +672,8 @@ async def get_scenes(
async def create_class(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
):
"""创建本体类型
@@ -636,7 +688,7 @@ async def create_class(
ApiResponse: 包含创建的类型信息
"""
from app.controllers.ontology_secondary_routes import create_class_handler
return await create_class_handler(request, db, current_user)
return await create_class_handler(request, db, current_user, x_language_type)
@router.put("/class/{class_id}", response_model=ApiResponse)

View File

@@ -7,11 +7,11 @@
from uuid import UUID
from typing import Optional
from fastapi import Depends
from fastapi import Depends, Header
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.logging_config import get_api_logger, get_business_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
@@ -30,9 +30,11 @@ from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
from app.repositories.ontology_class_repository import OntologyClassRepository
api_logger = get_api_logger()
business_logger = get_business_logger()
def _get_dummy_ontology_service(db: Session) -> OntologyService:
@@ -56,7 +58,7 @@ async def scenes_handler(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
page_size: Optional[int] = None,
pagesize: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -69,14 +71,14 @@ async def scenes_handler(
workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效
page_size: 每页数量(可选,仅在全量查询时有效)
pagesize: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if scene_name else "list"
api_logger.info(
f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
)
try:
@@ -103,13 +105,13 @@ async def scenes_handler(
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页)
@@ -117,17 +119,15 @@ async def scenes_handler(
total = len(scenes)
# 如果提供了分页参数,进行分页处理
if page is not None and page_size is not None:
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
if page is not None and pagesize is not None:
start_idx = (page - 1) * pagesize
end_idx = start_idx + pagesize
scenes = scenes[start_idx:end_idx]
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
@@ -139,17 +139,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
@@ -163,28 +162,25 @@ async def scenes_handler(
)
else:
# 获取所有场景(支持分页)
# 验证分页参数
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, page_size)
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
@@ -196,17 +192,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
@@ -236,7 +231,8 @@ async def scenes_handler(
async def create_class_handler(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = None
):
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
@@ -269,8 +265,11 @@ async def create_class_handler(
]
if count == 1:
# 单个创建
# 单个创建 - 先检查重名
class_data = classes_data[0]
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
if existing:
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
ontology_class = service.create_class(
scene_id=request.scene_id,
class_name=class_data["class_name"],
@@ -328,12 +327,36 @@ async def create_class_handler(
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e:
api_logger.warning(f"Validation error in class creation: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
err_str = str(e)
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
class_name = err_str.split(":", 1)[1]
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.warning(f"Validation error in class creation: {err_str}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
except RuntimeError as e:
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
err_str = str(e)
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
class_name = request.classes[0].class_name if request.classes else ""
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
@@ -366,6 +389,20 @@ async def update_class_handler(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试修改系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可修改",
"该类型为系统预设类型,不允许修改"
)
# 创建Service
service = _get_dummy_ontology_service(db)
@@ -429,6 +466,20 @@ async def delete_class_handler(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试删除系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可删除",
"该类型为系统预设类型,不允许删除"
)
# 创建Service
service = _get_dummy_ontology_service(db)
@@ -585,6 +636,7 @@ async def classes_handler(
scene_id=scene_uuid,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
is_system_default=scene.is_system_default,
items=items
)

View File

@@ -2,25 +2,32 @@ import hashlib
import json
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.core.response_utils import success, fail
from app.db import get_db, get_db_read
from app.dependencies import get_share_user_id, ShareTokenData
from app.models.app_model import App
from app.models.app_model import AppType
from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
from app.services.workflow_service import WorkflowService
from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
@@ -206,15 +213,13 @@ def list_conversations(
logger.debug(f"share_data:{share_data.user_id}")
other_id = share_data.user_id
service = SharedChatService(db)
share, release = service._get_release_by_share_token(share_data.share_token, password)
from app.repositories.end_user_repository import EndUserRepository
share, release = service.get_release_by_share_token(share_data.share_token, password)
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
other_id=other_id
)
logger.debug(new_end_user.id)
service = SharedChatService(db)
conversations, total = service.list_conversations(
share_token=share_data.share_token,
user_id=str(new_end_user.id),
@@ -293,19 +298,15 @@ async def chat(
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
try:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.app_service import AppService
# 验证分享链接和密码
share, release = service._get_release_by_share_token(share_token, password)
share, release = service.get_release_by_share_token(share_token, password)
# # Create end_user_id by concatenating app_id with user_id
# end_user_id = f"{share.app_id}_{user_id}"
# Store end_user_id in database with original user_id
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
@@ -318,7 +319,6 @@ async def chat(
"""获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app仅查询未删除的应用
from app.models.app_model import App
app = db.query(App).filter(
App.id == appid,
App.is_active.is_(True)
@@ -359,12 +359,12 @@ async def chat(
app_type = release.app.type if release.app else None
# 根据应用类型验证配置
if app_type == "agent":
if app_type == AppType.AGENT:
# Agent 类型:验证模型配置
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app_type == "multi_agent":
elif app_type == AppType.MULTI_AGENT:
# Multi-Agent 类型:验证多 Agent 配置
config = release.config or {}
if not config.get("sub_agents"):
@@ -638,6 +638,34 @@ async def chat(
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@router.get("/config", summary="获取应用启动配置")
async def config_query(
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
share_service = SharedChatService(db)
share_token = share_data.share_token
share, release = share_service.get_release_by_share_token(share_token, password)
if release.app.type == AppType.WORKFLOW:
workflow_service = WorkflowService(db)
content = {
"app_type": release.app.type,
"variables": workflow_service.get_start_node_variables(release.config)
}
elif release.app.type == AppType.AGENT:
content = {
"app_type": release.app.type,
"variables": release.config.get("variables")
}
elif release.app.type == AppType.MULTI_AGENT:
content = {
"app_type": release.app.type,
"variables": []
}
else:
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
return success(data=content)

View File

@@ -89,7 +89,6 @@ async def chat(
body = await request.json()
payload = AppChatRequest(**body)
other_id = payload.user_id
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
other_id = payload.user_id
workspace_id = app.workspace_id
@@ -135,7 +134,8 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
user_id=end_user_id,
is_draft=False
is_draft=False,
conversation_id=payload.conversation_id
)
if app_type == AppType.AGENT:
@@ -249,6 +249,7 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id,
public=True
):
event_type = event.get("event", "message")
event_data = event.get("data", {})

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)

View File

@@ -14,6 +14,7 @@ from app.models import User
from app.models.tool_model import ToolType, ToolStatus, AuthType
from app.services.tool_service import ToolService
from app.schemas.response_schema import ApiResponse
from app.core.exceptions import BusinessException
router = APIRouter(prefix="/tools", tags=["Tool System"])
@@ -97,7 +98,13 @@ async def create_tool(
):
"""创建工具"""
try:
tool_id = service.create_tool(
# 将 MCP 来源字段合并进 config
if request.tool_type == ToolType.MCP:
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
val = getattr(request, key, None)
if val is not None:
request.config[key] = val
tool_id = await service.create_tool(
name=request.name,
tool_type=request.tool_type,
tenant_id=current_user.tenant_id,
@@ -107,6 +114,8 @@ async def create_tool(
tags=request.tags
)
return success(data={"tool_id": tool_id}, msg="工具创建成功")
except BusinessException as e:
raise HTTPException(status_code=400, detail=e.message)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:

View File

@@ -1,7 +1,7 @@
import uuid
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
@@ -95,16 +95,29 @@ def get_workspaces(
@router.post("", response_model=ApiResponse)
def create_workspace(
workspace: WorkspaceCreate,
language_type: str = Header(default="zh", alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""创建新的工作空间"""
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
from app.core.language_utils import get_language_from_header
# 验证并获取语言参数
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
f"language={language}"
)
result = workspace_service.create_workspace(
db=db, workspace=workspace, user=current_user)
db=db, workspace=workspace, user=current_user, language=language
)
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
api_logger.info(
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
f"创建者: {current_user.username}, language={language}"
)
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功")

View File

@@ -11,35 +11,37 @@ LangChain Agent 封装
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
is_omni: bool = False,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
):
"""初始化 LangChain Agent
@@ -60,12 +62,13 @@ class LangChainAgent:
self.provider = provider
self.tools = tools or []
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls
# 工具调用计数器:记录每个工具的连续调用次数
self.tool_call_counter: Dict[str, int] = {}
self.last_tool_called: Optional[str] = None
# 根据工具数量动态调整最大迭代次数
# 基础值 + 每个工具额外的调用机会
if max_iterations is None:
@@ -73,9 +76,9 @@ class LangChainAgent:
self.max_iterations = 5 + len(self.tools) * 2
else:
self.max_iterations = max_iterations
self.system_prompt = system_prompt or "你是一个专业的AI助手"
logger.debug(
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
f"tool_count={len(self.tools)}, "
@@ -89,6 +92,7 @@ class LangChainAgent:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
@@ -143,21 +147,22 @@ class LangChainAgent:
"""
from langchain_core.tools import StructuredTool
from functools import wraps
wrapped_tools = []
for original_tool in tools:
tool_name = original_tool.name
original_func = original_tool.func if hasattr(original_tool, 'func') else None
if not original_func:
# 如果无法获取原始函数,直接使用原工具
wrapped_tools.append(original_tool)
continue
# 创建包装函数
def make_wrapped_func(tool_name, original_func):
"""创建包装函数的工厂函数,避免闭包问题"""
@wraps(original_func)
def wrapped_func(*args, **kwargs):
"""包装后的工具函数,跟踪连续调用次数"""
@@ -168,13 +173,13 @@ class LangChainAgent:
# 切换到新工具,重置计数器
self.tool_call_counter[tool_name] = 1
self.last_tool_called = tool_name
current_count = self.tool_call_counter[tool_name]
logger.debug(
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
)
# 检查是否超过最大连续调用次数
if current_count > self.max_tool_consecutive_calls:
logger.warning(
@@ -185,12 +190,12 @@ class LangChainAgent:
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
)
# 调用原始工具函数
return original_func(*args, **kwargs)
return wrapped_func
# 使用 StructuredTool 创建新工具
wrapped_tool = StructuredTool(
name=original_tool.name,
@@ -198,17 +203,17 @@ class LangChainAgent:
func=make_wrapped_func(tool_name, original_func),
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
)
wrapped_tools.append(wrapped_tool)
return wrapped_tools
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
) -> List[BaseMessage]:
"""准备消息列表
@@ -248,7 +253,7 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
@@ -261,23 +266,26 @@ class LangChainAgent:
List[Dict]: 消息内容列表
"""
# 根据 provider 使用不同的文本格式
if self.provider.lower() in ["bedrock", "anthropic"]:
# Anthropic/Bedrock: {"type": "text", "text": "..."}
content_parts = [{"type": "text", "text": text}]
else:
# 通义千问等: {"text": "..."}
content_parts = [{"text": text}]
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
# ModelProvider.GPUSTACK] or (
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
# content_parts = [{"type": "text", "text": text}]
# else:
# # 通义千问等: {"text": "..."}
# content_parts = [{"type": "text", "text": text}]
content_parts = [{"type": "text", "text": text}]
# 添加文件内容
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
content_parts.extend(files)
logger.debug(
f"构建多模态消息: provider={self.provider}, "
f"parts={len(content_parts)}, "
f"files={len(files)}"
)
return content_parts
async def chat(
@@ -302,7 +310,7 @@ class LangChainAgent:
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat= message
message_chat = message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
@@ -322,8 +330,8 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -367,14 +375,14 @@ class LangChainAgent:
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
logger.debug(f"输出消息数量: {len(output_messages)}")
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"找到 AI 消息content 类型: {type(msg.content)}")
logger.debug(f"AI 消息内容: {msg.content}")
# 处理多模态响应content 可能是字符串或列表
if isinstance(msg.content, str):
content = msg.content
@@ -407,12 +415,13 @@ class LangChainAgent:
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -439,16 +448,16 @@ class LangChainAgent:
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]:
"""执行流式对话
@@ -482,7 +491,6 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表(支持多模态)
@@ -500,13 +508,13 @@ class LangChainAgent:
full_content = ''
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
@@ -540,7 +548,7 @@ class LangChainAgent:
full_content += item
yield item
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
@@ -577,13 +585,13 @@ class LangChainAgent:
full_content += chunk
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
@@ -595,7 +603,8 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
@@ -609,5 +618,3 @@ class LangChainAgent:
logger.info("=" * 80)
logger.info("chat_stream 方法执行结束")
logger.info("=" * 80)

View File

@@ -1,9 +1,9 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Annotated, Optional
from dotenv import load_dotenv
from pydantic import Field, TypeAdapter
load_dotenv()
@@ -16,18 +16,18 @@ class Settings:
# cloud: SaaS 云服务版(全功能,按量计费)
# enterprise: 企业私有化版License 控制)
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
# License 配置(企业版)
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
# 计费服务配置SaaS 版)
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
# 基础 URL用于 SSO 回调等)
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
@@ -57,7 +57,6 @@ class Settings:
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
# ElasticSearch configuration
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
@@ -91,7 +90,7 @@ class Settings:
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# SSO 免登配置
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
@@ -115,6 +114,7 @@ class Settings:
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
# VOLC ASR settings
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
@@ -130,7 +130,7 @@ class Settings:
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
@@ -190,8 +190,12 @@ class Settings:
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
# SMTP Email Configuration
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
@@ -201,21 +205,30 @@ class Settings:
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Periodic Task Schedule Configuration
# workspace_reflection: 每隔多少秒执行一次
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30"))
# forgetting_cycle: 每隔多少小时执行一次
FORGETTING_CYCLE_INTERVAL_HOURS: int = int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24"))
# implicit_emotions_update: 每天几点执行小时0-23
# Celery Beat Schedule Configuration (定时任务执行频率)
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
# implicit_emotions_update: 每天几分执行分钟0-59
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0")) # Memory Module Configuration (internal)
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
@@ -232,27 +245,28 @@ class Settings:
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
# workflow config
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
# ========================================================================
# General Ontology Type Configuration
# ========================================================================
# 通用本体文件路径列表(逗号分隔)
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
# 是否启用通用本体类型功能
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
# Prompt 中最大类型数量
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
# 核心通用类型列表(逗号分隔)
CORE_GENERAL_TYPES: str = os.getenv(
"CORE_GENERAL_TYPES",
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
)
# 实验模式开关(允许通过 API 动态切换本体配置)
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"

View File

@@ -1,10 +1,10 @@
import os
import json
import os
import time
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
try:
# 使用优化的LLM服务
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
with get_db_context() as db_session:
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
@@ -111,7 +111,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"error_type": type(e).__name__,
"error_message": str(e),
"content_length": len(content),
"llm_model_id": memory_config.llm_model_id if memory_config else None
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
}
logger.error(f"Split_The_Problem error details: {error_details}")
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
try:
# 使用优化的LLM服务
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
with get_db_context() as db_session:
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
@@ -220,7 +221,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
"error_type": type(e).__name__,
"error_message": str(e),
"questions_count": len(databasets),
"llm_model_id": memory_config.llm_model_id if memory_config else None
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
}
logger.error(f"Problem_Extension error details: {error_details}")

View File

@@ -6,31 +6,26 @@ import os
# ===== 第三方库 =====
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
COUNTState,
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool,
extract_tool_message_content,
)
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__)
db = next(get_db())
async def rag_config(state):
@@ -50,10 +45,12 @@ async def rag_config(state):
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception :
retrieval_knowledge=[]
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def llm_infomation(state: ReadState) -> ReadState:
@@ -113,7 +110,7 @@ async def clean_databases(data) -> str:
# 收集所有内容
content_list = []
# 处理重排序结果
reranked = results.get('reranked_results', {})
if reranked:
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
elif isinstance(item, str):
text_parts.append(item)
return '\n'.join(text_parts).strip()
except Exception as e:
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState:
'''
模型信息
'''
problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '')
end_user_id=state.get('end_user_id', '')
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original=state.get('data', '')
problem_list=[]
for key,values in problem_extension.items():
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题
async def process_question_nodes(idx, question):
try:
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
send_verify = []
for i, j in zip(keys, val, strict=False):
if j!=['']:
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve':dup_databases}
return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id
import time
start=time.time()
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db)
return await llm_infomation(state)
llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
)
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool,hybrid_retrieval],
tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
async with SEMAPHORE: # 限制并发
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
# json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}

View File

@@ -1,5 +1,3 @@
import os
import time
@@ -17,33 +15,77 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
db_session = next(get_db())
class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
summary_service = SummaryNodeService()
async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
增强的summary_llm函数包含更好的错误处理和数据验证
"""
data = state.get("data", '')
# 构建系统提示词
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
@@ -62,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
)
try:
# 使用优化的LLM服务进行结构化输出
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
# 验证结构化响应
if structured is None:
logger.warning(f"LLM返回None使用默认回答")
logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答"
# 根据操作类型提取答案
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
@@ -82,18 +125,18 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
else:
logger.warning(f"结构化响应缺少data字段")
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"
# 验证答案不为空
if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答"
return aimessages
except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback
try:
logger.info("尝试非结构化输出作为fallback")
@@ -103,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
system_prompt=system_prompt,
fallback_message="信息不足,无法回答"
)
if response and response.strip():
# 简单清理响应
cleaned_response = response.strip()
@@ -111,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])
return cleaned_response
else:
return "信息不足,无法回答"
except Exception as fallback_error:
logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答"
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
@@ -132,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
)
await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
storage_type=state.get("storage_type",'')
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = {
"status": "success",
"summary_result": aimessages,
@@ -152,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
"user_rag_memory_id": user_rag_memory_id
}
}
retrieve={
retrieve = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"title":"快速检索",
"title": "快速检索",
"summary": aimessages,
"query": data,
"storage_type": storage_type,
@@ -167,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
}
}
return input_summary,retrieve
return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState:
start=time.time()
storage_type=state.get("storage_type",'')
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
end_user_id=state.get("end_user_id", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state)
history = await summary_history(state)
search_params = {
"end_user_id": end_user_id,
"question": data,
@@ -186,12 +233,14 @@ async def Input_Summary(state: ReadState) -> ReadState:
}
try:
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, []
try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
# 'input_summary',RetrieveSummaryResponse)
@@ -199,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0]
except Exception as e:
logger.error( f"Input_Summary failed: {e}", exc_info=True )
summary= {
logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary = {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
@@ -213,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
except Exception:
duration = 0.0
log_time('检索', duration)
return {"summary":summary}
return {"summary": summary}
async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve=state.get("retrieve", '')
history = await summary_history( state)
async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json
with open("检索.json","w",encoding='utf-8') as f:
with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve=retrieve.get("Expansion_issue", [])
start=time.time()
retrieve_info_str=[]
retrieve = retrieve.get("Expansion_issue", [])
start = time.time()
retrieve_info_str = []
for data in retrieve:
if data=='':
retrieve_info_str=''
if data == '':
retrieve_info_str = ''
else:
for key, value in data.items():
if key=='Answer_Small':
if key == 'Answer_Small':
for i in value:
retrieve_info_str.append(i)
retrieve_info_str=list(set(retrieve_info_str))
retrieve_info_str='\n'.join(retrieve_info_str)
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages=await summary_llm(state,history,retrieve_info_str,
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -248,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
return {"summary": summary}
async def Summary(state: ReadState)-> ReadState:
start=time.time()
async def Summary(state: ReadState) -> ReadState:
start = time.time()
query = state.get("data", '')
verify=state.get("verify", '')
verify_expansion_issue=verify.get("verified_data", '')
retrieve_info_str=''
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key=='answer_small':
if key == 'answer_small':
for i in value:
retrieve_info_str+=i+'\n'
history=await summary_history(state)
retrieve_info_str += i + '\n'
history = await summary_history(state)
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages=await summary_llm(state,history,data,
'summary_prompt.jinja2','summary',SummaryResponse,0)
aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
@@ -289,11 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
return {"summary": summary}
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '')
async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state)
query = state.get("data", '')
verify = state.get("verify", '')
@@ -309,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= {
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
return {"summary":result}
return {"summary": result}

View File

@@ -1,8 +1,9 @@
import asyncio
import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
# 将 VerificationItem 对象转换为字典列表
verified_data = []
if messages_deal.expansion_issue:
@@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
verified_data.append(item.model_dump())
elif isinstance(item, dict):
verified_data.append(item)
Verify_result = {
"status": messages_deal.split_result,
"verified_data": verified_data,
@@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
}
}
return Verify_result
async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
messages = {
"Query": content,
"Expansion_issue": retrieve_expansion
}
logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
@@ -94,29 +100,30 @@ async def Verify(state: ReadState):
json_schema=json_schema
)
logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}")
# 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)
with get_db_context() as db_session:
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值")
@@ -127,11 +134,11 @@ async def Verify(state: ReadState):
split_result="failed",
reason="LLM调用超时"
)
result = await Verify_prompt(state, structured)
logger.info("=== Verify 节点执行完成 ===")
return {"verify": result}
except Exception as e:
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
# 返回失败的验证结果
@@ -152,4 +159,4 @@ async def Verify(state: ReadState):
"user_rag_memory_id": state.get('user_rag_memory_id', '')
}
}
}
}

View File

@@ -1,3 +1,4 @@
from app.cache.memory.interest_memory import InterestMemoryCache
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger
@@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState:
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id=end_user_id,
language=lang,
)
if deleted:
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
write_result = {
"status": "success",
"data": structured_messages,

View File

@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
)
@asynccontextmanager
async def make_read_graph():
"""创建并返回 LangGraph 工作流"""
@@ -49,7 +47,7 @@ async def make_read_graph():
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails)
# 添加边
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
@@ -62,20 +60,20 @@ async def make_read_graph():
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END)
# 编译工作流
graph = workflow.compile()
yield graph
except Exception as e:
print(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
@@ -92,17 +90,19 @@ async def main():
service_name="MemoryAgentService"
)
import time
start=time.time()
start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
@@ -110,7 +110,7 @@ async def main():
):
for node_name, node_data in update_event.items():
print(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
@@ -125,23 +125,22 @@ async def main():
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data)
# Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension)
# Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
@@ -161,17 +160,20 @@ async def main():
#
print(f"=== 最终摘要 ===")
print(summary)
except Exception as e:
import traceback
traceback.print_exc()
finally:
db_session.close()
end=time.time()
print(100*'y')
print(f"总耗时: {end-start}s")
print(100*'y')
end = time.time()
print(100 * 'y')
print(f"总耗时: {end - start}s")
print(100 * 'y')
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
end_user_id: Group identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing
config_id: Configuration ID for processing (used to load pruning config)
Returns:
List of DialogData objects with generated chunks
@@ -57,6 +57,63 @@ async def get_chunked_dialogs(
end_user_id=end_user_id,
config_id=config_id
)
# 语义剪枝步骤(在分块之前)
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
from app.core.memory.models.config_models import PruningConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 加载剪枝配置
pruning_config = None
if config_id:
try:
with get_db_context() as db:
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="semantic_pruning"
)
if memory_config:
pruning_config = PruningConfig(
pruning_switch=memory_config.pruning_enabled,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=memory_config.pruning_threshold,
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
ontology_classes=memory_config.ontology_classes,
)
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
# 获取LLM客户端用于剪枝
if pruning_config.pruning_switch:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
original_msg_count = len(dialog_data.context.msgs)
# 使用 prune_dataset 而不是 prune_dialog
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
pruned_dialogs = await pruner.prune_dataset([dialog_data])
if pruned_dialogs:
dialog_data = pruned_dialogs[0]
remaining_msg_count = len(dialog_data.context.msgs)
deleted_count = original_msg_count - remaining_msg_count
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
else:
logger.warning("[剪枝] prune_dataset 返回空列表")
else:
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
except Exception as e:
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
except Exception as e:
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)

View File

@@ -1,56 +0,0 @@
import asyncio
from typing import Dict, Optional
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class LLMClientPool:
"""LLM客户端连接池"""
def __init__(self, max_size: int = 5):
self.max_size = max_size
self.pools: Dict[str, asyncio.Queue] = {}
self.active_clients: Dict[str, int] = {}
async def get_client(self, llm_model_id: str):
"""获取LLM客户端"""
if llm_model_id not in self.pools:
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
self.active_clients[llm_model_id] = 0
pool = self.pools[llm_model_id]
try:
# 尝试从池中获取客户端
client = pool.get_nowait()
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
return client
except asyncio.QueueEmpty:
# 池为空,创建新客户端
if self.active_clients[llm_model_id] < self.max_size:
db_session = next(get_db())
client = get_llm_client_fast(llm_model_id, db_session)
self.active_clients[llm_model_id] += 1
logger.debug(f"创建新LLM客户端: {llm_model_id}")
return client
else:
# 等待可用客户端
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
return await pool.get()
async def return_client(self, llm_model_id: str, client):
"""归还LLM客户端到池中"""
if llm_model_id in self.pools:
try:
self.pools[llm_model_id].put_nowait(client)
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
except asyncio.QueueFull:
# 池已满,丢弃客户端
self.active_clients[llm_model_id] -= 1
logger.debug(f"池已满丢弃LLM客户端: {llm_model_id}")
# 全局客户端池
llm_client_pool = LLMClientPool()

View File

@@ -225,5 +225,24 @@ async def write(
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
# 将提取统计写入 Redis按 workspace_id 存储
try:
from app.cache.memory.activity_stats_cache import ActivityStatsCache
stats_to_cache = {
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
"temporal_count": 0,
}
await ActivityStatsCache.set_activity_stats(
workspace_id=str(memory_config.workspace_id),
stats=stats_to_cache,
)
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -1,9 +1,12 @@
import asyncio
import json
import logging
import os
from typing import List, Tuple
from app.core.config import settings
logger = logging.getLogger(__name__)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
class InterestTags(BaseModel):
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
return structured_response.meaningful_tags
except Exception as e:
print(f"LLM筛选过程中发生错误: {e}")
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
# 在LLM失败时返回原始标签确保流程继续
return tags
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
"""
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
Args:
tags: 原始标签列表
end_user_id: 用户ID用于获取LLM配置
Returns:
筛选后的兴趣活动标签列表
"""
try:
with get_db_context() as db:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if not config_id and not workspace_id:
raise ValueError(
f"No memory_config_id found for end_user_id: {end_user_id}."
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
if not memory_config.llm_model_id:
raise ValueError(
f"No llm_model_id found in memory config {config_id}."
)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
tag_list_str = ", ".join(tags)
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
messages = [
{
"role": "user",
"content": rendered_prompt
}
]
structured_response = await llm_client.response_structured(
messages=messages,
response_model=InterestTags
)
return structured_response.interest_tags
except Exception as e:
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
return tags
async def get_raw_tags_from_db(
connector: Neo4jConnector,
end_user_id: str,
@@ -139,14 +210,14 @@ async def get_raw_tags_from_db(
return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
"""
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
查询更多的标签(40)给LLM提供更丰富的上下文进行筛选但最终返回数量由limit参数控制
Args:
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
limit: 最终返回的标签数量限制默认10
by_user: 是否按user_id查询默认False按end_user_id查询
Raises:
@@ -161,8 +232,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
# 使用项目的Neo4jConnector
connector = Neo4jConnector()
try:
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
# 1. 从数据库获取原始排名靠前的标签查询40条给LLM提供更丰富的上下文
query_limit = 40
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
if not raw_tags_with_freq:
return []
@@ -177,7 +249,61 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
if tag in meaningful_tag_names:
final_tags.append((tag, freq))
return final_tags
# 4. 限制返回的标签数量
return final_tags[:limit]
finally:
# 确保关闭连接
await connector.close()
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
"""
获取用户的兴趣分布标签。
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
Args:
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 最终返回的标签数量限制默认10
by_user: 是否按user_id查询默认False按end_user_id查询
Raises:
ValueError: 如果end_user_id未提供或为空
"""
if not end_user_id or not end_user_id.strip():
raise ValueError(
"end_user_id is required. Please provide a valid end_user_id or user_id."
)
connector = Neo4jConnector()
try:
# 查询更多原始标签给LLM提供充足上下文
query_limit = 40
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
if not raw_tags_with_freq:
return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
# 使用兴趣活动专用prompt进行筛选支持语义推断出新标签
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
# 构建最终标签列表:
# - 原始标签中存在的,保留原始频率
# - LLM推断出的新标签不在原始列表中赋予默认频率1
final_tags = []
seen = set()
for tag in interest_tag_names:
if tag in seen:
continue
seen.add(tag)
freq = raw_freq_map.get(tag, 1)
final_tags.append((tag, freq))
# 按频率降序排列
final_tags.sort(key=lambda x: x[1], reverse=True)
return final_tags[:limit]
finally:
await connector.close()

View File

@@ -10,7 +10,7 @@ Classes:
TemporalSearchParams: Parameters for temporal search queries
"""
from typing import Optional
from typing import Optional, List
from pydantic import BaseModel, Field
@@ -55,17 +55,26 @@ class PruningConfig(BaseModel):
Attributes:
pruning_switch: Enable or disable semantic pruning
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound')
pruning_scene: Scene name for pruning, either a built-in key
('education', 'online_service', 'outbound') or a custom scene_name
from ontology_scene table
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
scene_id: Optional ontology scene UUID, used to load custom ontology classes
ontology_classes: List of class_name strings from ontology_class table,
injected into the prompt when pruning_scene is not a built-in scene
"""
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
pruning_scene: str = Field(
"education",
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.",
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
)
pruning_threshold: float = Field(
0.5, ge=0.0, le=0.9,
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
ontology_classes: Optional[List[str]] = Field(
None, description="Class names from ontology_class table for custom scenes."
)
class TemporalSearchParams(BaseModel):

View File

@@ -5,20 +5,27 @@
- 对话级一次性抽取判定相关性
- 仅对"不相关对话"的消息按比例删除
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
"""
import asyncio
import os
import hashlib
import json
import re
from collections import OrderedDict
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Dict, Tuple, Set
from pydantic import BaseModel, Field
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
from app.core.memory.models.config_models import PruningConfig
from app.core.memory.utils.config.config_utils import get_pruning_config
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
SceneConfigRegistry,
ScenePatterns
)
class DialogExtractionResponse(BaseModel):
@@ -26,6 +33,7 @@ class DialogExtractionResponse(BaseModel):
- is_related对话与场景的相关性判定。
- times / ids / amounts / contacts / addresses / keywords重要信息片段用来在不相关对话中保留关键消息。
- preserve_keywords情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
"""
is_related: bool = Field(...)
times: List[str] = Field(default_factory=list)
@@ -34,6 +42,24 @@ class DialogExtractionResponse(BaseModel):
contacts: List[str] = Field(default_factory=list)
addresses: List[str] = Field(default_factory=list)
keywords: List[str] = Field(default_factory=list)
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
class MessageImportanceResponse(BaseModel):
"""消息重要性批量判断的结构化返回用于LLM语义判断
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
- reasons: 可选的判断理由
"""
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
class QAPair(BaseModel):
"""问答对模型,用于识别和保护对话中的问答结构。"""
question_idx: int = Field(..., description="问题消息的索引")
answer_idx: int = Field(..., description="答案消息的索引")
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
class SemanticPruner:
@@ -43,109 +69,280 @@ class SemanticPruner:
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
"""
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
cfg_dict = get_pruning_config() if config is None else config.model_dump()
self.config = PruningConfig.model_validate(cfg_dict)
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
# 如果没有提供config使用默认配置
if config is None:
# 使用默认的剪枝配置
config = PruningConfig(
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
pruning_scene="education",
pruning_threshold=0.5
)
self.config = config
self.llm_client = llm_client
self.language = language # 保存语言配置
self.max_concurrent = max_concurrent # 新增:最大并发数
# 详细日志配置:限制逐条消息日志的数量
self._detailed_prune_logging = True # 是否启用详细日志
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
# 加载统一填充词库
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
# 本体类型列表(用于注入提示词,所有场景均支持)
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
if self._ontology_classes:
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
else:
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
# Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
self._cache_max_size = 1000 # 缓存大小限制
# 运行日志:收集关键终端输出,便于写入 JSON
self.run_logs: List[str] = []
# 采用顺序处理,移除并发配置以简化与稳定执行
def _is_important_message(self, message: ConversationMessage) -> bool:
"""基于启发式规则识别重要信息消息,优先保留。
- 含日期/时间如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
- 关键词:"时间""日期""编号""订单""流水""金额""""""电话""手机号""邮箱""地址"
"""
import re
text = message.msg.strip()
if not text:
return False
patterns = [
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
r"\b\d{1,2}:\d{2}\b",
r"\d{4}\d{1,2}月\d{1,2}日",
r"上午|下午|AM|PM",
r"订单号|工单|申请号|编号|ID|账号|账户",
r"电话|手机号|微信|QQ|邮箱",
r"地址|地点",
r"金额|费用|价格|¥|¥|\d+元",
r"时间|日期|有效期|截止",
]
for p in patterns:
if re.search(p, text, flags=re.IGNORECASE):
return True
return False
def _importance_score(self, message: ConversationMessage) -> int:
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
简单启发:匹配到的类别越多、越关键分值越高。
"""
import re
text = message.msg.strip()
score = 0
weights = [
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
(r"\b\d{1,2}:\d{2}\b", 2),
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
(r"电话|手机号|微信|QQ|邮箱", 3),
(r"地址|地点", 2),
(r"金额|费用|价格|¥|¥|\d+元", 4),
(r"时间|日期|有效期|截止", 2),
]
for p, w in weights:
if re.search(p, text, flags=re.IGNORECASE):
score += w
return score
# _is_important_message 和 _importance_score 已移除:
# 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。
# LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。
def _is_filler_message(self, message: ConversationMessage) -> bool:
"""检测典型寒暄/口头禅/确认类短消息用于跳过LLM分类以加速
"""检测典型寒暄/口头禅/确认类短消息。
满足以下之一视为填充消息
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
判断顺序
1. 空消息
2. 场景特定填充词库精确匹配
3. 常见寒暄精确匹配
4. 纯表情/标点
"""
import re
t = message.msg.strip()
if not t:
return True
# 常见填充语
fillers = [
"你好", "您好", "在吗", "", "嗯嗯", "", "好的", "", "", "可以", "不可以", "谢谢",
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", ""
]
if t in fillers:
# 检查是否在场景特定填充词库中(精确匹配)
if t in self.scene_config.filler_phrases:
return True
# 长度与字符类型判断
if len(t) <= 8:
# 非数字、无关键实体的短文本
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
# 主要是标点或简单确认词
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
return True
# 常见寒暄和问候(精确匹配,避免误删)
common_greetings = {
"在吗", "在不在", "在呢", "在的",
"你好", "您好", "hello", "hi",
"拜拜", "再见", "", "88", "bye",
"好的", "", "", "可以", "", "", "",
"是的", "", "对的", "没错", "是啊",
"哈哈", "呵呵", "嘿嘿", "嗯嗯"
}
if t in common_greetings:
return True
# 检查是否为纯表情符号(方括号包裹)
if re.fullmatch(r"(\[[^\]]+\])+", t):
return True
# 纯标点符号
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
return True
return False
async def _batch_evaluate_importance_with_llm(
self,
messages: List[ConversationMessage],
context: str = ""
) -> Dict[int, int]:
"""使用LLM批量评估消息的重要性语义层面
Args:
messages: 消息列表
context: 对话上下文(可选)
Returns:
消息索引到重要性分数(0-10)的映射
"""
if not self.llm_client or not messages:
return {}
# 构建批量评估的提示词
msg_list = []
for idx, msg in enumerate(messages):
msg_list.append(f"{idx}. {msg.msg}")
msg_text = "\n".join(msg_list)
prompt = f"""请评估以下消息的重要性给每条消息打分0-10分
- 0-2分无意义的寒暄、口头禅、纯表情
- 3-5分一般性对话有一定信息量但不关键
- 6-8分包含重要信息时间、地点、人物、事件等
- 9-10分关键决策、承诺、重要数据
对话上下文:
{context if context else ""}
待评估的消息:
{msg_text}
请以JSON格式返回格式为
{{
"importance_scores": {{
"0": 分数,
"1": 分数,
...
}}
}}
"""
try:
messages_for_llm = [
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
{"role": "user", "content": prompt}
]
response = await self.llm_client.response_structured(
messages_for_llm,
MessageImportanceResponse
)
# 转换字符串键为整数键
return {int(k): v for k, v in response.importance_scores.items()}
except Exception as e:
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
return {}
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
"""识别对话中的问答对,用于保护问答结构的完整性。
改进版:使用场景特定的问句关键词,并排除寒暄类问句
Args:
messages: 消息列表
Returns:
问答对列表
"""
qa_pairs = []
# 寒暄类问句,不应该被保护(这些不是真正的问答)
greeting_questions = {
"在吗", "在不在", "你好吗", "怎么样", "好吗",
"有空吗", "忙吗", "睡了吗", "起床了吗"
}
for i in range(len(messages) - 1):
current_msg = messages[i].msg.strip()
next_msg = messages[i + 1].msg.strip()
# 排除寒暄类问句
if current_msg in greeting_questions:
continue
# 使用场景特定的问句关键词,但要求更严格
is_question = False
# 1. 以问号结尾
if current_msg.endswith("") or current_msg.endswith("?"):
is_question = True
# 2. 包含实质性问句关键词(排除"吗"这种太宽泛的)
elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时"]):
is_question = True
if is_question and next_msg:
# 检查下一条消息是否像答案(不是另一个问句,也不是寒暄)
is_answer = not (next_msg.endswith("") or next_msg.endswith("?"))
# 排除寒暄类回复
greeting_answers = {"你好", "您好", "在呢", "在的", "", "", "好的"}
if next_msg in greeting_answers:
is_answer = False
if is_answer:
qa_pairs.append(QAPair(
question_idx=i,
answer_idx=i + 1,
confidence=0.8 # 基于规则的置信度
))
return qa_pairs
def _get_protected_indices(
self,
messages: List[ConversationMessage],
qa_pairs: List[QAPair],
window_size: int = 2
) -> Set[int]:
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
Args:
messages: 消息列表
qa_pairs: 问答对列表
window_size: 上下文窗口大小(前后各保留几条消息)
Returns:
需要保护的消息索引集合
"""
protected = set()
for qa_pair in qa_pairs:
# 保护问答对本身
protected.add(qa_pair.question_idx)
protected.add(qa_pair.answer_idx)
# 保护上下文窗口
for offset in range(-window_size, window_size + 1):
q_idx = qa_pair.question_idx + offset
a_idx = qa_pair.answer_idx + offset
if 0 <= q_idx < len(messages):
protected.add(q_idx)
if 0 <= a_idx < len(messages):
protected.add(a_idx)
return protected
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
- 仅使用 LLM 结构化输出;
改进版:
- LRU缓存管理
- 重试机制
- 降级策略
"""
# 缓存命中则直接返回(场景+内容作为键)
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
# LRU缓存如果命中移到末尾最近使用
if cache_key in self._dialog_extract_cache:
self._dialog_extract_cache.move_to_end(cache_key)
return self._dialog_extract_cache[cache_key]
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
# LRU缓存大小限制超过限制时删除最旧的条目
if len(self._dialog_extract_cache) >= self._cache_max_size:
# 删除最旧的条目OrderedDict的第一个
oldest_key = next(iter(self._dialog_extract_cache))
del self._dialog_extract_cache[oldest_key]
self._log(f"[剪枝-缓存] LRU缓存已满删除最旧条目")
rendered = self.template.render(
pruning_scene=self.config.pruning_scene,
ontology_classes=self._ontology_classes,
dialog_text=dialog_text,
language=self.language
)
log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene,
"ontology_classes_count": len(self._ontology_classes),
"language": self.language
})
log_prompt_rendering("pruning-extract", rendered)
# 强制使用 LLM;移除正则回退
# 强制使用 LLM
if not self.llm_client:
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
@@ -153,12 +350,32 @@ class SemanticPruner:
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
{"role": "user", "content": rendered},
]
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
# 重试机制
max_retries = 3
for attempt in range(max_retries):
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
if attempt < max_retries - 1:
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
continue
else:
# 降级策略:标记为相关,避免误删
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
fallback_response = DialogExtractionResponse(
is_related=True,
times=[],
ids=[],
amounts=[],
contacts=[],
addresses=[],
keywords=[]
)
return fallback_response
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
"""判断消息是否包含任意抽取到的重要片段。"""
@@ -184,62 +401,56 @@ class SemanticPruner:
# 相关对话不剪枝
return dialog
# 在不相关对话中,识别重要/不重要消息
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
preserve_tokens = (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords
)
msgs = dialog.context.msgs
imp_unrel_msgs: List[ConversationMessage] = []
unimp_unrel_msgs: List[ConversationMessage] = []
# 分类:填充 / 其他可删LLM保护消息通过不加入任何桶来隐式保护
filler_ids: set = set()
deletable: List[ConversationMessage] = []
for m in msgs:
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
imp_unrel_msgs.append(m)
if self._msg_matches_tokens(m, preserve_tokens):
pass # 保护消息:不加入任何桶,不会被删除
elif self._is_filler_message(m):
filler_ids.add(id(m))
else:
unimp_unrel_msgs.append(m)
# 计算总删除目标数量
deletable.append(m)
# 计算删除目标
total_unrel = len(msgs)
delete_target = int(total_unrel * proportion)
if proportion > 0 and total_unrel > 0 and delete_target == 0:
delete_target = 1
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
unimp_del_cap = len(unimp_unrel_msgs)
max_capacity = max(0, len(msgs) - 1)
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1))
delete_target = min(delete_target, max_deletable)
# 删除配额分配
del_unimp = min(delete_target, unimp_del_cap)
rem = delete_target - del_unimp
del_imp = min(rem, imp_del_cap)
# 选取删除集合
unimp_delete_ids = []
imp_delete_ids = []
if del_unimp > 0:
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
if del_imp > 0:
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
# 统计实际删除数量(重要/不重要)
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept_msgs = []
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
# 优先删填充,再删其他可删消息(按出现顺序)
to_delete_ids: set = set()
for m in msgs:
mid = id(m)
if mid in delete_targets:
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
actual_unimp_deleted += 1
continue
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
actual_imp_deleted += 1
continue
kept_msgs.append(m)
if len(to_delete_ids) >= delete_target:
break
if id(m) in filler_ids:
to_delete_ids.add(id(m))
for m in deletable:
if len(to_delete_ids) >= delete_target:
break
to_delete_ids.add(id(m))
kept_msgs = [m for m in msgs if id(m) not in to_delete_ids]
if not kept_msgs and msgs:
kept_msgs = [msgs[0]]
deleted_total = actual_unimp_deleted + actual_imp_deleted
deleted_total = len(msgs) - len(kept_msgs)
protected_count = len(msgs) - len(filler_ids) - len(deletable)
self._log(
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} "
f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) "
f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
)
dialog.context = ConversationContext(msgs=kept_msgs)
@@ -248,12 +459,14 @@ class SemanticPruner:
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
"""数据集层面:全局消息级剪枝,保留所有对话。
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
- 保证每段对话至少保留1条消息不会删除整段对话。
改进版:
- 消息级独立判断,每条消息根据场景规则独立评估
- 问答对保护已注释(暂不启用,留作观察)
- 优化删除策略:填充消息 → 不重要消息 → 低分重要消息
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
- 保证每段对话至少保留1条消息不会删除整段对话
"""
# 如果剪枝功能关闭,直接返回原始数据集
# 如果剪枝功能关闭,直接返回原始数据集
if not self.config.pruning_switch:
return dialogs
@@ -264,179 +477,139 @@ class SemanticPruner:
proportion = 0.9
if proportion < 0.0:
proportion = 0.0
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
self._log(
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
)
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
evaluated_dialogs = []
for idx, dd in enumerate(dialogs):
try:
ex = await self._extract_dialog_important(dd.content)
evaluated_dialogs.append({
"dialog": dd,
"is_related": bool(ex.is_related),
"index": idx,
"extraction": ex
})
except Exception:
evaluated_dialogs.append({
"dialog": dd,
"is_related": True,
"index": idx,
"extraction": None
})
# 统计相关 / 不相关对话
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
self._log(
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
)
# 简洁打印第几段对话相关/不相关索引基于1
def _fmt_indices(items, cap: int = 10):
inds = [i["index"] + 1 for i in items]
if len(inds) <= cap:
return inds
# 超过上限时只打印前cap个并标注总数
return inds[:cap] + ["...", f"{len(inds)}"]
rel_inds = _fmt_indices(related_dialogs)
nrel_inds = _fmt_indices(not_related_dialogs)
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}")
result: List[DialogData] = []
if not_related_dialogs:
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM
per_dialog_info = {}
total_unrelated = 0
total_capacity = 0
for d in not_related_dialogs:
dd = d["dialog"]
extraction = d.get("extraction")
if extraction is None:
extraction = await self._extract_dialog_important(dd.content)
# 合并所有重要标记
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
msgs = dd.context.msgs
# 分类消息
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
# 重要消息按重要性排序
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
info = {
"dialog": dd,
"total_msgs": len(msgs),
"unrelated_count": len(msgs),
"imp_ids_sorted": imp_sorted_ids,
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
}
per_dialog_info[d["index"]] = info
total_unrelated += info["unrelated_count"]
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
global_delete = int(total_unrelated * proportion)
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
global_delete = 1
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例)且至少保留1条消息
capacities = []
for d in not_related_dialogs:
idx = d["index"]
info = per_dialog_info[idx]
# 统计重要数量
imp_count = len(info["imp_ids_sorted"])
unimp_count = len(info["unimp_ids"])
imp_cap = int(imp_count * proportion)
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
capacities.append(cap)
total_capacity = sum(capacities)
if global_delete > total_capacity:
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
global_delete = total_capacity
total_original_msgs = 0
total_deleted_msgs = 0
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
alloc = []
for i, d in enumerate(not_related_dialogs):
idx = d["index"]
info = per_dialog_info[idx]
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
alloc.append(min(share, capacities[i]))
allocated = sum(alloc)
rem = global_delete - allocated
turn = 0
while rem > 0 and turn < 100000:
progressed = False
for i in range(len(not_related_dialogs)):
if rem <= 0:
break
if alloc[i] < capacities[i]:
alloc[i] += 1
rem -= 1
progressed = True
if not progressed:
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
semaphore = asyncio.Semaphore(self.max_concurrent)
async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse:
async with semaphore:
try:
return await self._extract_dialog_important(dd.content)
except Exception as e:
self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}")
return DialogExtractionResponse(is_related=True)
extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs]
extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks)
for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)):
msgs = dd.context.msgs
original_count = len(msgs)
total_original_msgs += original_count
# 从 LLM 抽取结果中获取所有需要保留的 token
preserve_tokens = (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
)
# 判断是否需要详细日志
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
if extraction.preserve_keywords:
self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}")
# 消息级分类LLM保护 / 填充 / 其他可删
llm_protected_msgs = [] # LLM 保护消息preserve_tokens 命中):绝对不可删除
filler_msgs = [] # 填充消息(优先删除)
deletable_msgs = [] # 其余消息(按比例删除)
for idx, m in enumerate(msgs):
msg_text = m.msg.strip()
if self._msg_matches_tokens(m, preserve_tokens):
llm_protected_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 保护LLM不可删")
elif self._is_filler_message(m):
filler_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
else:
deletable_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 可删")
# important_msgs 仅用于日志统计
important_msgs = llm_protected_msgs
# 计算删除配额
delete_target = int(original_count * proportion)
if proportion > 0 and original_count > 0 and delete_target == 0:
delete_target = 1
# 确保至少保留1条消息
max_deletable = max(0, original_count - 1)
delete_target = min(delete_target, max_deletable)
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息
to_delete_indices = set()
deleted_details = []
# 第一步:删除填充消息
for idx, msg in filler_msgs:
if len(to_delete_indices) >= delete_target:
break
turn += 1
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
total_deleted_confirm = 0
for d in evaluated_dialogs:
dd = d["dialog"]
msgs = dd.context.msgs
original = len(msgs)
if d["is_related"]:
result.append(dd)
continue
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
if idx_in_unrel is None:
result.append(dd)
continue
quota = alloc[idx_in_unrel]
info = per_dialog_info[d["index"]]
# 计算本对话重要最多可删数量
imp_count = len(info["imp_ids_sorted"])
imp_del_cap = int(imp_count * proportion)
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
del_unimp = min(quota, len(unimp_delete_ids))
rem_quota = quota - del_unimp
# 再从重要里选低分优先的删除ID不超过 imp_del_cap
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
deleted_here = 0
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept = []
for m in msgs:
mid = id(m)
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
actual_unimp_deleted += 1
deleted_here += 1
continue
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
actual_imp_deleted += 1
deleted_here += 1
continue
kept.append(m)
if not kept and msgs:
kept = [msgs[0]]
dd.context.msgs = kept
total_deleted_confirm += deleted_here
self._log(
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
else:
# 全部相关:不执行剪枝
result = [d["dialog"] for d in evaluated_dialogs]
# 第二步:如果还需要删除,按出现顺序删可删消息
for idx, msg in deletable_msgs:
if len(to_delete_indices) >= delete_target:
break
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
# 执行删除
kept_msgs = []
for idx, m in enumerate(msgs):
if idx not in to_delete_indices:
kept_msgs.append(m)
# 确保至少保留1条
if not kept_msgs and msgs:
kept_msgs = [msgs[0]]
dd.context.msgs = kept_msgs
deleted_count = original_count - len(kept_msgs)
total_deleted_msgs += deleted_count
# 输出删除详情
if deleted_details:
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
for detail in deleted_details:
self._log(f" {detail}")
# ========== 问答对统计(已注释) ==========
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
# ========================================
self._log(
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
f"删除={deleted_count} 保留={len(kept_msgs)}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
# 保存日志
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
payload = self._parse_logs_to_structured(sanitized_logs)
with open(log_output_path, "w", encoding="utf-8") as f:
@@ -448,6 +621,7 @@ class SemanticPruner:
if not result:
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs
return result
def _log(self, msg: str) -> None:

View File

@@ -0,0 +1,66 @@
"""
场景特定配置 - 统一填充词库
重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
本模块仅保留统一填充词库filler_phrases用于识别无意义寒暄/表情/口头禅。
所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
"""
from typing import List, Set
from dataclasses import dataclass, field
@dataclass
class ScenePatterns:
"""场景特定的识别模式(仅保留填充词库)"""
filler_phrases: Set[str] = field(default_factory=set)
class SceneConfigRegistry:
"""场景配置注册表 - 所有场景共用统一填充词库"""
BASE_FILLERS: Set[str] = {
# 基础寒暄
"你好", "您好", "在吗", "在的", "在呢", "", "嗯嗯", "", "哦哦",
"好的", "", "", "可以", "不可以", "谢谢", "多谢", "感谢",
"拜拜", "再见", "88", "", "回见",
# 口头禅
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
"", "", "", "", "", "", "嗯哼",
# 确认词
"是的", "", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
# 服务类套话
"请问", "请稍等", "稍等", "马上", "立即",
"正在查询", "正在处理", "正在为您", "帮您查一下",
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
"感谢您的耐心等待", "抱歉让您久等了",
"已记录", "已反馈", "已转接", "已升级",
"祝您生活愉快", "欢迎下次咨询",
# 外呼套话
"", "hello", "打扰了", "不好意思",
"方便接电话吗", "现在方便吗", "占用您一点时间",
"我是", "我们是", "我们公司", "我们这边",
"了解一下", "介绍一下", "简单说一下",
"考虑考虑", "想一想", "再说", "再看看",
"不需要", "不感兴趣", "没兴趣", "不用了",
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
# 教育场景套话
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
"举手", "请坐", "很好", "不错", "继续",
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
# 标点和符号
"。。。", "...", "???", "", "!!!", "",
# 表情符号
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
# 网络用语
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
"emmm", "emm", "em", "mmp", "wtf", "omg",
}
@classmethod
def get_config(cls, scene: str = "") -> ScenePatterns:
"""所有场景统一返回同一份填充词库"""
return ScenePatterns(filler_phrases=cls.BASE_FILLERS)

View File

@@ -1932,17 +1932,17 @@ def preprocess_data(
Returns:
经过清洗转换后的 DialogData 列表
"""
print("\n=== 数据预处理 ===")
logger.debug("=== 数据预处理 ===")
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
DataPreprocessor,
)
preprocessor = DataPreprocessor()
try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
return cleaned_data
except Exception as e:
print(f"数据预处理过程中出现错误: {e}")
logger.error(f"数据预处理过程中出现错误: {e}")
raise
@@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed(
Returns:
带 chunks 的 DialogData 列表
"""
print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
if not data:
raise ValueError("预处理数据为空,无法进行分块")
@@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing(
input_data_path: Optional[str] = None,
llm_client: Optional[Any] = None,
skip_cleaning: bool = True,
pruning_config: Optional[Dict] = None,
) -> List[DialogData]:
"""包含数据预处理步骤的完整分块流程
@@ -2000,11 +2001,12 @@ async def get_chunked_dialogs_with_preprocessing(
input_data_path: 输入数据路径
llm_client: LLM 客户端
skip_cleaning: 是否跳过数据清洗步骤默认False
pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold
Returns:
带 chunks 的 DialogData 列表
"""
print("\n=== 完整数据处理流程(包含预处理)===")
logger.debug("=== 完整数据处理流程(包含预处理)===")
if input_data_path is None:
input_data_path = os.path.join(
@@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing(
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
pruner = SemanticPruner(llm_client=llm_client)
from app.core.memory.models.config_models import PruningConfig
# 构建剪枝配置
if pruning_config:
# 使用传入的配置
config = PruningConfig(**pruning_config)
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
else:
# 使用默认配置(关闭剪枝)
config = None
logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)")
pruner = SemanticPruner(config=config, llm_client=llm_client)
# 记录单对话场景下剪枝前的消息数量
single_dialog_original_msgs = None
@@ -2043,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing(
if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None:
remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0
deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs)
print(
logger.debug(
f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs}"
f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。"
)
else:
print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
# 保存剪枝后的数据
try:
@@ -2059,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing(
dp = DataPreprocessor(output_file_path=pruned_output_path)
dp.save_data(preprocessed_data, output_path=pruned_output_path)
except Exception as se:
print(f"保存剪枝结果失败:{se}")
logger.error(f"保存剪枝结果失败:{se}")
except Exception as e:
print(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
# 步骤3: 对话分块
return await get_chunked_dialogs_from_preprocessed(

View File

@@ -1,5 +1,7 @@
import os
from typing import Optional
from typing import Optional, List, Any
from enum import Enum
from pathlib import Path
from app.core.logging_config import get_memory_logger
from app.core.memory.models.message_models import DialogData, Chunk
@@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config
logger = get_memory_logger(__name__)
class ChunkerStrategy(Enum):
"""Supported chunking strategies."""
RECURSIVE = "RecursiveChunker"
SEMANTIC = "SemanticChunker"
LATE = "LateChunker"
NEURAL = "NeuralChunker"
LLM = "LLMChunker"
@classmethod
def get_valid_strategies(cls) -> List[str]:
"""Get list of valid strategy names."""
return [strategy.value for strategy in cls]
class DialogueChunker:
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
@@ -17,23 +33,51 @@ class DialogueChunker:
of different chunking strategies to dialogue data.
"""
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None):
"""Initialize the DialogueChunker with a specific chunking strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker
llm_client: LLM client instance (required for LLMChunker strategy)
Raises:
ValueError: If chunker_strategy is invalid or required parameters are missing
"""
self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# Validate strategy
valid_strategies = ChunkerStrategy.get_valid_strategies()
if chunker_strategy not in valid_strategies:
raise ValueError(
f"Invalid chunker_strategy: '{chunker_strategy}'. "
f"Must be one of {valid_strategies}"
)
if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
self.chunker_client = ChunkerClient(self.chunker_config)
self.chunker_strategy = chunker_strategy
logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}")
try:
# Load and validate configuration
chunker_config_dict = get_chunker_config(chunker_strategy)
if not chunker_config_dict:
raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}")
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# Initialize chunker client
if self.chunker_config.chunker_strategy == "LLMChunker":
if not llm_client:
raise ValueError("llm_client is required for LLMChunker strategy")
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
self.chunker_client = ChunkerClient(self.chunker_config)
logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}")
except Exception as e:
logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True)
raise
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]:
"""Process a dialogue by generating chunks and adding them to the DialogData object.
Args:
@@ -43,54 +87,125 @@ class DialogueChunker:
A list of Chunk objects
Raises:
ValueError: If chunking fails or returns empty chunks
ValueError: If dialogue is invalid or chunking fails
Exception: If chunking process encounters an error
"""
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
chunks = result_dialogue.chunks
if not chunks or len(chunks) == 0:
# Validate input
if not dialogue:
raise ValueError("dialogue cannot be None")
if not dialogue.context or not dialogue.context.msgs:
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
f"Dialogue {dialogue.ref_id} has no messages to chunk. "
f"Context: {dialogue.context is not None}, "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}"
)
logger.info(
f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages "
f"using strategy: {self.chunker_strategy}"
)
try:
# Generate chunks
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
chunks = result_dialogue.chunks
return chunks
# Validate results
if not chunks or len(chunks) == 0:
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs)}, "
f"Content length: {len(dialogue.content) if dialogue.content else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
)
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
logger.info(
f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. "
f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}"
)
return chunks
except ValueError:
# Re-raise validation errors
raise
except Exception as e:
logger.error(
f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}",
exc_info=True
)
raise
def save_chunking_results(
self,
chunks: List[Chunk],
dialogue: DialogData,
output_path: Optional[str] = None,
preview_length: int = 100
) -> str:
"""Save the chunking results to a file and return the output path.
Args:
dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output
chunks: List of Chunk objects to save
dialogue: The DialogData object that was processed
output_path: Optional path to save the output (defaults to current directory)
preview_length: Maximum length of content preview (default: 100)
Returns:
The path where the output was saved
Raises:
ValueError: If chunks or dialogue is invalid
IOError: If file writing fails
"""
if not output_path:
output_path = os.path.join(
os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt"
)
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs)} messages",
f"Total characters: {len(dialogue.content)}",
f"Generated {len(dialogue.chunks)} chunks:"
]
# Validate input
if not chunks:
raise ValueError("chunks list cannot be empty")
if not dialogue:
raise ValueError("dialogue cannot be None")
for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...")
if chunk.metadata:
output_lines.append(f" Metadata: {chunk.metadata}")
# Generate default output path if not provided
if not output_path:
output_dir = Path(__file__).parent.parent.parent
output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt")
logger.info(f"Saving chunking results to: {output_path}")
try:
# Prepare output content
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages",
f"Total characters: {len(dialogue.content) if dialogue.content else 0}",
f"Generated {len(chunks)} chunks:",
""
]
for i, chunk in enumerate(chunks, 1):
content_preview = chunk.content[:preview_length] if chunk.content else ""
if len(chunk.content) > preview_length:
content_preview += "..."
output_lines.append(f" Chunk {i}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {content_preview}")
if chunk.metadata:
output_lines.append(f" Metadata: {chunk.metadata}")
output_lines.append("")
with open(output_path, "w", encoding="utf-8") as f:
f.write("\n".join(output_lines))
# Write to file
with open(output_path, "w", encoding="utf-8") as f:
f.write("\n".join(output_lines))
logger.info(f"Chunking results saved to: {output_path}")
return output_path
logger.info(f"Successfully saved chunking results to: {output_path}")
return output_path
except IOError as e:
logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True)
raise

View File

@@ -327,7 +327,7 @@ class MultiOntologyParser:
Example:
>>> parser = MultiOntologyParser([
... "General_purpose_entity.ttl",
... "app/core/memory/ontology_services/General_purpose_entity.ttl",
... "domain_specific.owl"
... ])
>>> registry = parser.parse_all()

View File

@@ -400,7 +400,8 @@ async def render_user_summary_prompt(
user_id: str,
entities: str,
statements: str,
language: str = "zh"
language: str = "zh",
user_display_name: str = None
) -> str:
"""
Renders the user summary prompt using the user_summary.jinja2 template.
@@ -410,16 +411,22 @@ async def render_user_summary_prompt(
entities: Core entities with frequency information
statements: Representative statement samples
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user")
Returns:
Rendered prompt content as string
"""
# 如果没有提供 user_display_name使用默认值
if user_display_name is None:
user_display_name = "该用户" if language == "zh" else "the user"
template = prompt_env.get_template("user_summary.jinja2")
rendered_prompt = template.render(
user_id=user_id,
entities=entities,
statements=statements,
language=language
language=language,
user_display_name=user_display_name
)
# 记录渲染结果到提示日志
@@ -429,7 +436,8 @@ async def render_user_summary_prompt(
'user_id': user_id,
'entities_len': len(entities),
'statements_len': len(statements),
'language': language
'language': language,
'user_display_name': user_display_name
})
return rendered_prompt
@@ -540,3 +548,20 @@ async def render_ontology_extraction_prompt(
})
return rendered_prompt
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
"""
Renders the interest filter prompt using the interest_filter.jinja2 template.
Args:
tag_list: Comma-separated string of raw tags to filter
language: Output language ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("interest_filter.jinja2")
rendered_prompt = template.render(tag_list=tag_list, language=language)
log_prompt_rendering('interest filter', rendered_prompt)
return rendered_prompt

View File

@@ -1,6 +1,6 @@
{#
对话级抽取与相关性判定模板(用于剪枝加速)
输入pruning_scene, dialog_text
输入pruning_scene, ontology_classes, dialog_text, language
输出:严格 JSON不要包含任何多余文本字段
- is_related: bool是否与所选场景相关
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
@@ -9,39 +9,71 @@
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等
- addresses: [string],地址/地点相关文本
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
要求:
- 必须只输出上述 JSON且键名一致不得输出解释、前后缀不得包含注释。
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。
- times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
- 仅输出上述键;避免多余解释或字段。
#}
{% set scene_instructions = {
'education': {
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
},
'online_service': {
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
},
'outbound': {
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
}
} %}
{% set scene_key = pruning_scene %}
{% if scene_key not in scene_instructions %}
{% set scene_key = 'education' %}
{# ── 确定场景说明 ── #}
{% if ontology_classes and ontology_classes | length > 0 %}
{% if language == 'en' %}
{% set custom_types_str = ontology_classes | join(', ') %}
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
{% else %}
{% set custom_types_str = ontology_classes | join('、') %}
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
{% endif %}
{% else %}
{% if language == 'en' %}
{% set custom_types_str = '' %}
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
{% else %}
{% set custom_types_str = '' %}
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
{% endif %}
{% endif %}
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
{% if language == "zh" %}
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性
场景说明:{{ instruction }}
你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务
1. 判断对话是否与指定场景相关;
2. 从对话中抽取所有需要保留的重要信息片段。
场景说明:{{ instruction }}
{% if custom_types_str %}
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }}相关的内容即判定为相关is_related=true
{% endif %}
---
【必须保留的内容(不可删除)】
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
- 时间信息:日期、时间点、时间段、有效期 → times 字段
- 编号信息学号、工号、订单号、申请号、账号、ID → ids 字段
- 金额信息:价格、费用、金额(含货币符号或单位) → amounts 字段
- 联系方式电话、手机号、邮箱、微信、QQ → contacts 字段
- 地址信息:地点、地址、位置 → addresses 字段
- 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
- **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
【可以删除的内容】
以下类型的内容属于低价值信息,可以在剪枝时删除:
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
例如:
- "我好开心呀" → 包含情绪开心必须保留preserve_keywords 中加入"开心"
- "好喜欢打羽毛球呀" → 包含兴趣爱好喜欢打羽毛球必须保留preserve_keywords 中加入"喜欢打羽毛球"
- "我好难过" → 包含情绪难过必须保留preserve_keywords 中加入"难过"
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪必须保留preserve_keywords 中加入"开心"
---
对话全文:
"""
{{ dialog_text }}
@@ -55,12 +87,46 @@
"amounts": [<string>...],
"contacts": [<string>...],
"addresses": [<string>...],
"keywords": [<string>...]
"keywords": [<string>...],
"preserve_keywords": [<string>...]
}
{% else %}
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
Scenario Description: {{ instruction }}
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
1. Determine whether the dialogue is relevant to the specified scene;
2. Extract all important information fragments that must be preserved.
Scenario Description: {{ instruction }}
{% if custom_types_str %}
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
{% endif %}
---
[MUST PRESERVE (cannot be deleted)]
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
- Time information: dates, time points, durations, expiry dates → times field
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
- Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
- Address information: locations, addresses, places → addresses field
- Scene keywords: professional terms and event names strongly related to the scene → keywords field
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
- **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
[CAN BE DELETED]
The following types of content are low-value and can be removed during pruning:
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
Examples:
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
---
Full Dialogue:
"""
{{ dialog_text }}
@@ -74,6 +140,7 @@ Output strict JSON only (fixed keys, order doesn't matter):
"amounts": [<string>...],
"contacts": [<string>...],
"addresses": [<string>...],
"keywords": [<string>...]
"keywords": [<string>...],
"preserve_keywords": [<string>...]
}
{% endif %}

View File

@@ -0,0 +1,67 @@
{% if language == "zh" %}
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
**Step 1 - Infer the underlying interest from each tag**:
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
Examples of inference:
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
- '川味可视化', '川菜' → '烹饪'
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
- '吉他', '指弹', '琴谱' → '吉他'
- '跑步', '5公里', '跑鞋' → '跑步'
- '瑜伽垫', '瑜伽课' → '瑜伽'
**Step 2 - Consolidate and deduplicate**:
- Merge tags that point to the same interest into one representative label
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
- If multiple tags all point to '攀岩', output '攀岩' only once
**Step 3 - Filter out non-interest tags**:
Remove tags that do NOT suggest any hobby or interest:
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
**Output format**: Return a list of concise interest activity names in Chinese.
**Example**:
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
{% else %}
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
**Step 1 - Infer the underlying interest from each tag**:
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
Examples of inference:
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
- 'morning meditation streak', 'mind-body peak' → 'meditation'
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
- 'open source project', 'data visualization tool', 'Python' → 'programming'
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
- 'running', '5km', 'running shoes' → 'running'
**Step 2 - Consolidate and deduplicate**:
- Merge tags that point to the same interest into one representative label
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
**Step 3 - Filter out non-interest tags**:
Remove tags that do NOT suggest any hobby or interest:
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
**Output format**: Return a list of concise interest activity names in English.
**Example**:
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
{% endif %}

View File

@@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti
{% endif %}
===Inputs===
{% if user_id %}
- User ID: {{ user_id }}
{% if user_display_name %}
- User Display Name: {{ user_display_name }}
{% endif %}
{% if entities %}
- Core Entities & Frequency: {{ entities }}
@@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti
3. Avoid excessive adjectives and empty phrases
4. Strictly follow the output format specified below
{% if language == "zh" %}
**【严格人称规定】**
- 在描述用户时,必须使用"{{ user_display_name }}"作为人称
- 绝对禁止使用用户ID如 {{ user_id }})来称呼用户
- 绝对禁止在摘要中出现任何形式的UUID或ID字符串
- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA
{% else %}
**【STRICT PRONOUN RULES】**
- When describing the user, you MUST use "{{ user_display_name }}" as the reference
- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user
- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary
- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they)
{% endif %}
**Section-Specific Requirements:**
{% if language == "zh" %}
@@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti
{% if language == "zh" %}
Example Input:
- User ID: user_12345
- User Display Name: 张三
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
Example Output:
【基本介绍】
我是张三一名充满热情的高级产品经理。在过去的5年里专注于AI和数据驱动的产品设计致力于创造能够真正改善用户生活的产品。相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
张三一名充满热情的高级产品经理,在深圳工作。在过去的5年里张三专注于AI和数据驱动的产品设计致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
【性格特点】
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
@@ -121,13 +135,13 @@ Example Output:
"让每一个产品决策都充满温度。"
{% else %}
Example Input:
- User ID: user_12345
- User Display Name: John
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
Example Output:
【Basic Introduction】
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
【Personality Traits】
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.

View File

@@ -21,31 +21,55 @@ from pydantic import BaseModel, Field
T = TypeVar("T")
class RedBearModelConfig(BaseModel):
"""模型配置基类"""
model_name: str
provider: str
api_key: str
base_url: Optional[str] = None
is_omni: bool = False # 是否为 Omni 模型
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用可通过环境变量 LLM_TIMEOUT 配置
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
# 最大重试次数 - 默认2次以避免过长等待可通过环境变量 LLM_MAX_RETRIES 配置
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
concurrency: int = 5 # 并发限流
concurrency: int = 5 # 并发限流
extra_params: Dict[str, Any] = {}
class RedBearModelFactory:
"""模型工厂类"""
@classmethod
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
"""根据提供商获取模型参数"""
provider = config.provider.lower()
# 打印供应商信息用于调试
from app.core.logging_config import get_business_logger
logger = get_business_logger()
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}")
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
import httpx
if not config.base_url:
config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
timeout_config = httpx.Timeout(
timeout=config.timeout,
connect=60.0,
read=config.timeout,
write=60.0,
pool=10.0,
)
return {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
"timeout": timeout_config,
"max_retries": config.max_retries,
**config.extra_params
}
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
@@ -65,7 +89,7 @@ class RedBearModelFactory:
"timeout": timeout_config,
"max_retries": config.max_retries,
**config.extra_params
}
}
elif provider == ModelProvider.DASHSCOPE:
# DashScope (通义千问) 使用自己的参数格式
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
@@ -82,7 +106,7 @@ class RedBearModelFactory:
# region 从 base_url 或 extra_params 获取
from botocore.config import Config as BotoConfig
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
# Configure with increased connection pool
@@ -90,16 +114,16 @@ class RedBearModelFactory:
max_pool_connections=max_pool_connections,
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
)
# 标准化模型 ID自动转换简化名称为完整 Bedrock Model ID
model_id = normalize_bedrock_model_id(config.model_name)
params = {
"model_id": model_id,
"config": boto_config,
**config.extra_params
}
# 解析 API key (格式: access_key_id:secret_access_key)
if config.api_key and ":" in config.api_key:
access_key_id, secret_access_key = config.api_key.split(":", 1)
@@ -107,45 +131,52 @@ class RedBearModelFactory:
params["aws_secret_access_key"] = secret_access_key
elif config.api_key:
params["aws_access_key_id"] = config.api_key
# 设置 region
if config.base_url:
params["region_name"] = config.base_url
elif "region_name" not in params:
params["region_name"] = "us-east-1" # 默认区域
return params
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
@classmethod
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
"""根据提供商获取模型参数"""
provider = config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
return {
return {
"model": config.model_name,
# "base_url": config.base_url,
"jina_api_key": config.api_key,
**config.extra_params
}
}
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
"""根据模型提供商获取对应的模型类"""
provider = config.provider.lower()
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
from langchain_openai import ChatOpenAI
return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if type == ModelType.LLM:
from langchain_openai import OpenAI
return OpenAI
return OpenAI
elif type == ModelType.CHAT:
from langchain_openai import ChatOpenAI
return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.chat_models import ChatTongyi
return ChatTongyi
elif provider == ModelProvider.OLLAMA:
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaLLM
return OllamaLLM
elif provider == ModelProvider.BEDROCK:
@@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
"""根据模型提供商获取对应的模型类"""
provider = provider.lower()
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings
return OpenAIEmbeddings
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.embeddings import DashScopeEmbeddings
return DashScopeEmbeddings
return DashScopeEmbeddings
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings
@@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_rerank_class(provider: str):
"""根据模型提供商获取对应的模型类"""
provider = provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
provider = provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_community.document_compressors import JinaRerank
return JinaRerank
# elif provider == ModelProvider.OLLAMA:
return JinaRerank
# elif provider == ModelProvider.OLLAMA:
# from langchain_ollama import OllamaEmbeddings
# return OllamaEmbeddings
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

View File

@@ -6,6 +6,8 @@ models:
description: AI21 Labs大语言模型completion生成模式256000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: bedrock
@@ -15,6 +17,9 @@ models:
description: Amazon Nova大语言模型支持智能体思考、工具调用、流式工具调用、视觉能力300000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -28,6 +33,9 @@ models:
description: Anthropic Claude大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -42,6 +50,8 @@ models:
description: Cohere大语言模型支持智能体思考、工具调用、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -54,6 +64,9 @@ models:
description: DeepSeek大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -67,6 +80,8 @@ models:
description: Meta Llama大语言模型支持智能体思考、工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -78,6 +93,8 @@ models:
description: Mistral AI大语言模型支持智能体思考、工具调用32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -89,6 +106,8 @@ models:
description: OpenAI大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -101,6 +120,8 @@ models:
description: Qwen大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -113,6 +134,8 @@ models:
description: amazon.rerank-v1:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: bedrock
@@ -122,6 +145,8 @@ models:
description: cohere.rerank-v3-5:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: bedrock
@@ -131,6 +156,9 @@ models:
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型支持视觉能力8192上下文窗口
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 文本嵌入模型
- vision
@@ -141,6 +169,8 @@ models:
description: amazon.titan-embed-text-v1文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -150,6 +180,8 @@ models:
description: amazon.titan-embed-text-v2:0文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -159,6 +191,8 @@ models:
description: Cohere Embed 3 English文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -168,6 +202,8 @@ models:
description: Cohere Embed 3 Multilingual文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
logo: bedrock

View File

@@ -6,6 +6,8 @@ models:
description: DeepSeek-R1-Distill-Qwen-14B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -16,6 +18,8 @@ models:
description: DeepSeek-R1-Distill-Qwen-32B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -26,6 +30,8 @@ models:
description: DeepSeek-R1大语言模型支持智能体思考131072超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -36,6 +42,8 @@ models:
description: DeepSeek-V3.1大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -46,6 +54,8 @@ models:
description: DeepSeek-V3.2-exp实验版大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -56,6 +66,8 @@ models:
description: DeepSeek-V3.2大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -66,6 +78,8 @@ models:
description: DeepSeek-V3大语言模型支持智能体思考64000上下文窗口对话模式支持文本与JSON格式输出
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -76,6 +90,8 @@ models:
description: farui-plus大语言模型支持多工具调用、智能体思考、流式工具调用12288上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -88,6 +104,8 @@ models:
description: GLM-4.7大语言模型支持多工具调用、智能体思考、流式工具调用202752超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -100,6 +118,9 @@ models:
description: qvq-max-latest大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- vision
@@ -112,6 +133,9 @@ models:
description: qvq-max大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- vision
@@ -124,6 +148,8 @@ models:
description: qwen-coder-turbo-0919代码专用大语言模型支持智能体思考131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -135,6 +161,8 @@ models:
description: qwen-max-latest大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -147,6 +175,8 @@ models:
description: qwen-max-longcontext长上下文大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -159,6 +189,8 @@ models:
description: qwen-max大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -171,6 +203,8 @@ models:
description: qwen-mt-plus多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 翻译模型
@@ -182,6 +216,8 @@ models:
description: qwen-mt-turbo轻量化多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 翻译模型
@@ -193,6 +229,8 @@ models:
description: qwen-plus-0112大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -205,6 +243,8 @@ models:
description: qwen-plus-0125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -217,6 +257,8 @@ models:
description: qwen-plus-0723大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -229,6 +271,8 @@ models:
description: qwen-plus-0806大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -241,6 +285,8 @@ models:
description: qwen-plus-0919大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -253,6 +299,8 @@ models:
description: qwen-plus-1125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -265,6 +313,8 @@ models:
description: qwen-plus-1127大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -277,6 +327,8 @@ models:
description: qwen-plus-1220大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -289,6 +341,10 @@ models:
description: qwen-vl-max多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -302,6 +358,10 @@ models:
description: qwen-vl-plus-0809多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -315,6 +375,10 @@ models:
description: qwen-vl-plus-2025-01-02多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -328,6 +392,10 @@ models:
description: qwen-vl-plus-2025-01-25多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -341,6 +409,10 @@ models:
description: qwen-vl-plus-latest多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -354,6 +426,10 @@ models:
description: qwen-vl-plus多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -367,6 +443,8 @@ models:
description: qwen2.5-0.5b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -379,6 +457,8 @@ models:
description: qwen3-14b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -391,6 +471,8 @@ models:
description: qwen3-235b-a22b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -403,6 +485,8 @@ models:
description: qwen3-235b-a22b-thinking-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -415,6 +499,8 @@ models:
description: qwen3-235b-a22b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -427,6 +513,8 @@ models:
description: qwen3-30b-a3b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -439,6 +527,8 @@ models:
description: qwen3-30b-a3b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -451,6 +541,8 @@ models:
description: qwen3-32b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -463,6 +555,8 @@ models:
description: qwen3-4b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -475,6 +569,8 @@ models:
description: qwen3-8b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -487,6 +583,8 @@ models:
description: qwen3-coder-30b-a3b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -498,6 +596,8 @@ models:
description: qwen3-coder-480b-a35b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -509,6 +609,8 @@ models:
description: qwen3-coder-plus-2025-09-23大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -520,6 +622,8 @@ models:
description: qwen3-coder-plus大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -531,6 +635,8 @@ models:
description: qwen3-max-2025-09-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -544,6 +650,8 @@ models:
description: qwen3-max-2026-01-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -557,6 +665,8 @@ models:
description: qwen3-max-preview大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -569,6 +679,8 @@ models:
description: qwen3-max大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -582,6 +694,8 @@ models:
description: qwen3-next-80b-a3b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -594,6 +708,8 @@ models:
description: qwen3-next-80b-a3b-thinking大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -606,6 +722,11 @@ models:
description: qwen3-omni-flash-2025-12-01多模态大语言模型支持视觉、智能体思考、视频、音频能力65536上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
- audio
is_omni: true
tags:
- 大语言模型
- 多模态模型
@@ -620,6 +741,10 @@ models:
description: qwen3-vl-235b-a22b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -635,6 +760,10 @@ models:
description: qwen3-vl-235b-a22b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -650,6 +779,10 @@ models:
description: qwen3-vl-30b-a3b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -665,6 +798,10 @@ models:
description: qwen3-vl-30b-a3b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -680,6 +817,10 @@ models:
description: qwen3-vl-flash多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -695,6 +836,10 @@ models:
description: qwen3-vl-plus-2025-09-23多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -708,6 +853,10 @@ models:
description: qwen3-vl-plus多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -721,6 +870,8 @@ models:
description: qwq-32b大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -732,6 +883,8 @@ models:
description: qwq-plus-0305大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -743,6 +896,8 @@ models:
description: qwq-plus大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -754,6 +909,8 @@ models:
description: gte-rerank-v2重排序模型4000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: dashscope
@@ -763,6 +920,8 @@ models:
description: gte-rerank重排序模型4000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: dashscope
@@ -772,6 +931,9 @@ models:
description: multimodal-embedding-v1多模态嵌入模型支持视觉能力8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 嵌入模型
- 多模态模型
@@ -783,6 +945,8 @@ models:
description: text-embedding-v1文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -793,6 +957,8 @@ models:
description: text-embedding-v2文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -803,6 +969,8 @@ models:
description: text-embedding-v3文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -813,7 +981,9 @@ models:
description: text-embedding-v4文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope
logo: dashscope

View File

@@ -6,7 +6,7 @@ from typing import Callable
import yaml
from sqlalchemy.orm import Session
from app.models.models_model import ModelBase, ModelProvider
from app.models.models_model import ModelBase, ModelProvider, ModelConfig
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
@@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
print(f"\n正在加载 {provider.value}{len(models)} 个模型...")
for model_data in models:
config_sync_fields = {
"logo": None,
"capability": None,
"is_omni": None,
"name": None,
"provider": None,
"type": None,
"description": None
}
try:
# 检查模型是否已存在
existing = db.query(ModelBase).filter(
@@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
# 更新现有模型配置
for key, value in model_data.items():
setattr(existing, key, value)
# 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey
sync_fields = [k for k in config_sync_fields.keys() if k in model_data]
if sync_fields:
# 批量更新 ModelConfig
update_kwargs = {k: model_data[k] for k in sync_fields}
db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update(
update_kwargs,
synchronize_session=False
)
# 更新 ModelApiKey 的 capability 和 is_omni
if 'capability' in model_data or 'is_omni' in model_data:
from app.models.models_model import ModelApiKey, model_config_api_key_association
api_key_update = {}
if 'capability' in model_data:
api_key_update['capability'] = model_data['capability']
if 'is_omni' in model_data:
api_key_update['is_omni'] = model_data['is_omni']
if api_key_update:
# 查找所有关联的 API Key
api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join(
ModelConfig,
ModelConfig.id == model_config_api_key_association.c.model_config_id
).filter(ModelConfig.model_id == existing.id).distinct().all()
if api_key_ids:
api_key_ids = [aid[0] for aid in api_key_ids]
db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update(
api_key_update,
synchronize_session=False
)
db.commit()
if not silent:
print(f"更新成功: {model_data['name']}")

View File

@@ -6,12 +6,19 @@ models:
description: chatgpt-4o-latest大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- audio
- video
is_omni: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- audio
- video
logo: openai
- name: gpt-3.5-turbo-0125
type: llm
@@ -19,6 +26,8 @@ models:
description: gpt-3.5-turbo-0125大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -31,6 +40,8 @@ models:
description: gpt-3.5-turbo-1106大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -43,6 +54,8 @@ models:
description: gpt-3.5-turbo-16k大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -55,6 +68,8 @@ models:
description: gpt-3.5-turbo-instruct大语言模型4096上下文窗口文本补全模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: openai
@@ -64,6 +79,8 @@ models:
description: gpt-3.5-turbo大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -76,6 +93,8 @@ models:
description: gpt-4-0125-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -88,6 +107,8 @@ models:
description: gpt-4-1106-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -100,6 +121,9 @@ models:
description: gpt-4-turbo-2024-04-09大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -113,6 +137,8 @@ models:
description: gpt-4-turbo-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -125,6 +151,9 @@ models:
description: gpt-4-turbo大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -138,6 +167,8 @@ models:
description: o1-preview大语言模型支持智能体思考128000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -148,6 +179,9 @@ models:
description: o1大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -162,6 +196,9 @@ models:
description: o3-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -176,6 +213,8 @@ models:
description: o3-mini-2025-01-31大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -189,6 +228,8 @@ models:
description: o3-mini大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -202,6 +243,9 @@ models:
description: o3-pro-2025-06-10大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -215,6 +259,9 @@ models:
description: o3-pro大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -228,6 +275,9 @@ models:
description: o3大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -242,6 +292,9 @@ models:
description: o4-mini-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -256,6 +309,9 @@ models:
description: o4-mini大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -270,6 +326,8 @@ models:
description: text-embedding-3-large文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
@@ -279,6 +337,8 @@ models:
description: text-embedding-3-small文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
@@ -288,6 +348,8 @@ models:
description: text-embedding-ada-002文本向量模型8097上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
logo: openai

View File

@@ -4,11 +4,12 @@ RAG chunk analysis utilities.
from .chunk_summary import generate_chunk_summary
from .chunk_tags import extract_chunk_tags, extract_chunk_persona
from .chunk_insight import generate_chunk_insight
from .chunk_insight import generate_chunk_insight, generate_chunk_insight_sections
__all__ = [
"generate_chunk_summary",
"extract_chunk_tags",
"extract_chunk_persona",
"generate_chunk_insight",
"generate_chunk_insight_sections",
]

View File

@@ -1,213 +1,207 @@
"""
Generate insights from RAG chunks.
Generate memory insight report for RAG chunks using memory_insight.jinja2 prompt template.
This module provides functionality to analyze chunk content and generate insights using LLM.
The memory_insight.jinja2 template produces a four-section report:
【总体概述】 → memory_insight
【行为模式】 → behavior_pattern
【关键发现】 → key_findings
【成长轨迹】 → growth_trajectory
generate_chunk_insight() returns the full raw text (stored in end_user.memory_insight).
generate_chunk_insight_sections() returns a dict with all four fields for richer storage.
"""
import asyncio
import os
import re
from collections import Counter
from typing import Any, Dict, List
from typing import Dict, List, Optional
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
business_logger = get_business_logger()
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
def _get_llm_client():
"""Get LLM client using db context."""
# ── LLM client helper ────────────────────────────────────────────────────────
def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
return factory.get_llm_client(DEFAULT_LLM_ID)
class ChunkInsight(BaseModel):
"""Pydantic model for chunk insight."""
insight: str = Field(..., description="对chunk内容的深度洞察分析")
# ── Domain analysis helpers (kept for building prompt inputs) ─────────────────
async def _classify_domain(chunk: str, llm_client) -> str:
"""Classify a single chunk into a domain category."""
from pydantic import BaseModel, Field
class DomainClassification(BaseModel):
"""Pydantic model for domain classification."""
domain: str = Field(
...,
description="内容所属的领域分类",
examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"]
)
class _Domain(BaseModel):
domain: str = Field(..., description="领域分类")
async def classify_chunk_domain(chunk: str) -> str:
"""
Classify a chunk into a specific domain.
Args:
chunk: Chunk content string
Returns:
Domain name
"""
try:
llm_client = _get_llm_client()
prompt = f"""请将以下文本内容归类到最合适的领域中。
可选领域及其关键词:
- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等
- 商业:市场、销售、管理、财务、投资、创业、营销、战略等
- 教育:学习、课程、培训、教学、知识、技能、考试、研究等
- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等
- 娱乐:游戏、电影、音乐、体育、艺术、文化等
- 健康:医疗、养生、运动、心理、保健、疾病等
- 其他:无法归入以上类别的内容
文本内容: {chunk[:500]}...
请直接返回最合适的领域名称。"""
messages = [
{"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"},
{"role": "user", "content": prompt}
]
classification = await llm_client.response_structured(
messages=messages,
response_model=DomainClassification
prompt = (
"请将以下文本归类到最合适的领域(技术/商业/教育/生活/娱乐/健康/其他)。\n\n"
f"文本: {chunk[:500]}\n\n直接返回领域名称。"
)
return classification.domain if classification else "其他"
except Exception as e:
business_logger.error(f"分类chunk领域失败: {str(e)}")
result = await llm_client.response_structured(
messages=[{"role": "user", "content": prompt}],
response_model=_Domain,
)
return result.domain if result else "其他"
except Exception:
return "其他"
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
async def _build_insight_inputs(
chunks: List[str],
max_chunks: int,
end_user_id: Optional[str],
) -> Dict[str, Optional[str]]:
"""
Analyze the domain distribution of chunks.
Args:
chunks: List of chunk content strings
max_chunks: Maximum number of chunks to analyze
Returns:
Dictionary of domain -> percentage
Derive domain_distribution, active_periods, social_connections strings
to feed into the memory_insight.jinja2 template.
"""
if not chunks:
return {}
try:
# 限制分析的chunk数量
chunks_to_analyze = chunks[:max_chunks]
# 为每个chunk分类
domain_counts = Counter()
for chunk in chunks_to_analyze:
domain = await classify_chunk_domain(chunk)
domain_counts[domain] += 1
# 计算百分比
total = sum(domain_counts.values())
domain_distribution = {
domain: count / total
for domain, count in domain_counts.items()
}
# 按百分比降序排序
return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True))
except Exception as e:
business_logger.error(f"分析领域分布失败: {str(e)}")
return {}
llm_client = _get_llm_client(end_user_id)
chunks_sample = chunks[:max_chunks]
# Domain distribution
domain_counts: Counter = Counter()
for chunk in chunks_sample:
domain = await _classify_domain(chunk, llm_client)
domain_counts[domain] += 1
total = sum(domain_counts.values()) or 1
domain_distribution = ", ".join(
f"{d}({c / total:.0%})" for d, c in domain_counts.most_common(3)
)
return {
"domain_distribution": domain_distribution,
"active_periods": None, # RAG模式暂无时间维度数据
"social_connections": None, # RAG模式暂无社交关联数据
}
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
# ── Section parser ────────────────────────────────────────────────────────────
_ZH_SECTIONS = {
"memory_insight": r"【总体概述】(.*?)(?=【|$)",
"behavior_pattern": r"【行为模式】(.*?)(?=【|$)",
"key_findings": r"【关键发现】(.*?)(?=【|$)",
"growth_trajectory": r"【成长轨迹】(.*?)(?=【|$)",
}
_EN_SECTIONS = {
"memory_insight": r"【Overview】(.*?)(?=【|$)",
"behavior_pattern": r"【Behavior Pattern】(.*?)(?=【|$)",
"key_findings": r"【Key Findings】(.*?)(?=【|$)",
"growth_trajectory": r"【Growth Trajectory】(.*?)(?=【|$)",
}
def _parse_sections(text: str, language: str = "zh") -> Dict[str, str]:
"""Extract the four sections from the LLM output."""
patterns = _ZH_SECTIONS if language == "zh" else _EN_SECTIONS
result = {}
for key, pattern in patterns.items():
match = re.search(pattern, text, re.DOTALL)
result[key] = match.group(1).strip() if match else ""
return result
# ── Public API ────────────────────────────────────────────────────────────────
async def generate_chunk_insight(
chunks: List[str],
max_chunks: int = 15,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> str:
"""
Generate insights from the given chunks.
Args:
chunks: List of chunk content strings
max_chunks: Maximum number of chunks to analyze
Returns:
A comprehensive insight report
Generate a memory insight report from RAG chunks.
Returns the full raw report text (suitable for end_user.memory_insight).
Use generate_chunk_insight_sections() when you need all four dimensions.
"""
sections = await generate_chunk_insight_sections(
chunks=chunks,
max_chunks=max_chunks,
end_user_id=end_user_id,
language=language,
)
return sections.get("memory_insight") or sections.get("_raw", "洞察生成失败")
async def generate_chunk_insight_sections(
chunks: List[str],
max_chunks: int = 15,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> Dict[str, str]:
"""
Generate a four-section memory insight report from RAG chunks.
Returns a dict with keys:
memory_insight, behavior_pattern, key_findings, growth_trajectory
(plus '_raw' containing the full LLM output for debugging)
"""
if not chunks:
business_logger.warning("没有提供chunk内容用于生成洞察")
return "暂无足够数据生成洞察报告"
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
empty["_raw"] = "暂无足够数据生成洞察报告"
return empty
try:
# 1. 分析领域分布
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
# 2. 统计基本信息
total_chunks = len(chunks)
avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0
# 3. 构建洞察prompt
prompt_parts = []
if domain_dist:
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
prompt_parts.append(f"- 内容领域分布: {top_domains}")
prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}")
# 添加部分chunk内容作为参考
sample_chunks = chunks[:5]
sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)])
prompt_parts.append(f"\n内容示例:\n{sample_content}")
system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
重要规则:
1. 报告需要将所有要点流畅地串联成一个段落
2. 语言风格要专业、客观,同时易于理解
3. 不要添加任何额外的解释或标题,直接输出报告内容
4. 基于提供的数据和示例内容进行分析,不要编造信息
5. 重点关注内容的主题、特点和价值
6. 报告长度控制在150-200字
# Build template inputs from chunk analysis
inputs = await _build_insight_inputs(chunks, max_chunks, end_user_id)
例如,如果输入是:
- 内容领域分布: 技术(60%), 商业(25%), 教育(15%)
- 内容规模: 共50个知识片段平均长度320字
内容示例: [示例内容...]
rendered_prompt = await render_memory_insight_prompt(
domain_distribution=inputs["domain_distribution"],
active_periods=inputs["active_periods"],
social_connections=inputs["social_connections"],
language=language,
)
你的输出应该类似:
"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段平均每个片段约320字内容详实。从示例来看内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。"
"""
user_prompt = "\n".join(prompt_parts)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 调用LLM生成洞察
llm_client = _get_llm_client()
messages = [{"role": "user", "content": rendered_prompt}]
llm_client = _get_llm_client(end_user_id)
response = await llm_client.chat(messages=messages)
insight = response.content.strip()
business_logger.info(f"成功生成chunk洞察分析了 {min(len(chunks), max_chunks)} 个片段")
return insight
raw_text = response.content.strip() if response and response.content else ""
sections = _parse_sections(raw_text, language=language)
sections["_raw"] = raw_text
business_logger.info(
f"成功生成chunk洞察四维度分析了 {min(len(chunks), max_chunks)} 个片段"
)
return sections
except Exception as e:
business_logger.error(f"生成chunk洞察失败: {str(e)}")
return "洞察生成失败"
if __name__ == "__main__":
# 测试代码
test_chunks = [
"Python是一种高级编程语言以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。",
"机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。",
"深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。",
"自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。",
"数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。"
]
print("开始生成chunk洞察...")
insight = asyncio.run(generate_chunk_insight(test_chunks))
print(f"\n生成的洞察:\n{insight}")
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
empty["_raw"] = "洞察生成失败"
return empty

View File

@@ -1,11 +1,10 @@
"""
Generate summary for RAG chunks.
This module provides functionality to summarize chunk content using LLM.
Generate summary for RAG chunks using memory_summary.jinja2 prompt template.
"""
import asyncio
from typing import Any, Dict, List
import os
from typing import List, Optional
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -14,94 +13,135 @@ from pydantic import BaseModel, Field
business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
class ChunkSummary(BaseModel):
"""Pydantic model for chunk summary."""
summary: str = Field(..., description="简洁的chunk内容摘要")
# ── Schema ──────────────────────────────────────────────────────────────────
class MemorySummaryStatement(BaseModel):
"""Single labelled statement extracted by memory_summary.jinja2."""
statement: str = Field(..., description="提取的陈述内容")
label: Optional[str] = Field(None, description="陈述标签")
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
class MemorySummaryResponse(BaseModel):
"""
Generate a summary for the given chunks.
Structured output expected from memory_summary.jinja2.
The template asks for a JSON array of labelled statements;
we wrap it in an object so response_structured can parse it.
"""
statements: List[MemorySummaryStatement] = Field(
default_factory=list,
description="从chunk中提取的陈述列表"
)
summary: Optional[str] = Field(None, description="整体摘要文本(可选)")
# ── LLM client helper ────────────────────────────────────────────────────────
def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db)
return factory.get_llm_client(DEFAULT_LLM_ID)
# ── Core function ─────────────────────────────────────────────────────────────
async def generate_chunk_summary(
chunks: List[str],
max_chunks: int = 10,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> str:
"""
Generate a user summary from RAG chunks using the memory_summary.jinja2 template.
The template extracts labelled statements from the chunks; we then join them
into a coherent summary string that can be stored in end_user.user_summary.
Args:
chunks: List of chunk content strings
max_chunks: Maximum number of chunks to process (default: 10)
max_chunks: Maximum number of chunks to process
end_user_id: Optional end-user ID for model selection
language: Output language ("zh" or "en")
Returns:
A concise summary of the chunks
Summary string (joined statements or fallback text)
"""
if not chunks:
business_logger.warning("没有提供chunk内容用于生成摘要")
return "暂无内容"
try:
# 限制处理的chunk数量避免token过多
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
chunks_to_process = chunks[:max_chunks]
# 合并chunk内容
combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)])
# 构建prompt
system_prompt = (
"你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n"
"- 摘要长度控制在100-150字\n"
"- 提取核心信息和关键要点;\n"
"- 使用客观、清晰的语言;\n"
"- 避免冗余和重复;\n"
"- 如果内容涉及多个主题,按重要性排序呈现。"
chunk_texts = "\n\n".join(
[f"片段{i + 1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]
)
json_schema = MemorySummaryResponse.model_json_schema()
rendered_prompt = await render_memory_summary_prompt(
chunk_texts=chunk_texts,
json_schema=json_schema,
max_words=200,
language=language,
)
messages = [{"role": "user", "content": rendered_prompt}]
llm_client = _get_llm_client(end_user_id)
# Try structured output; fall back to plain chat only for LLMClientException
# (indicates the model/provider doesn't support structured output).
# All other exceptions are re-raised so config/schema errors stay visible.
try:
response: MemorySummaryResponse = await llm_client.response_structured(
messages=messages,
response_model=MemorySummaryResponse,
)
if response.summary:
summary = response.summary.strip()
elif response.statements:
summary = "".join(s.statement for s in response.statements)
else:
summary = "暂无内容"
except Exception as e:
from app.core.memory.llm_tools.llm_client import LLMClientException
if isinstance(e, LLMClientException):
business_logger.warning(
f"结构化输出不可用,降级为普通对话: end_user_id={end_user_id}, reason={e}"
)
raw = await llm_client.chat(messages=messages)
summary = raw.content.strip() if raw and raw.content else "暂无内容"
else:
business_logger.error(f"生成摘要时发生非预期异常: {e}")
raise
business_logger.info(
f"成功生成chunk摘要处理了 {len(chunks_to_process)} 个片段"
)
user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}"
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
# 调用LLM生成摘要
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages)
summary = response.content.strip()
business_logger.info(f"成功生成chunk摘要处理了 {len(chunks_to_process)} 个片段")
return summary
except Exception as e:
business_logger.error(f"生成chunk摘要失败: {str(e)}")
return "摘要生成失败"
async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]:
"""
Generate summaries for multiple chunk lists in batch.
Args:
chunks_list: List of chunk lists
Returns:
List of summaries
"""
tasks = [generate_chunk_summary(chunks) for chunks in chunks_list]
return await asyncio.gather(*tasks)
if __name__ == "__main__":
# 测试代码
test_chunks = [
"这是第一段测试内容,讲述了关于机器学习的基础知识。",
"第二段内容介绍了深度学习的应用场景和发展历史。",
"第三段讨论了自然语言处理技术的最新进展。"
]
print("开始生成chunk摘要...")
summary = asyncio.run(generate_chunk_summary(test_chunks))
print(f"\n生成的摘要:\n{summary}")

View File

@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
"""
import asyncio
import os
from collections import Counter
from typing import List, Tuple
from typing import List, Optional, Tuple
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
business_logger = get_business_logger()
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
def _get_llm_client():
"""Get LLM client using db context."""
def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
return factory.get_llm_client(DEFAULT_LLM_ID)
class ExtractedTags(BaseModel):
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师''旅行爱好者'")
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
"""
Extract meaningful tags from the given chunks.
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
)
llm_client = _get_llm_client()
llm_client = _get_llm_client(end_user_id)
# 为每个chunk单独提取标签然后统计频率
all_tags = []
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
"""
Extract persona (人物形象) from the given chunks.
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
]
# 调用LLM提取人物形象
llm_client = _get_llm_client()
llm_client = _get_llm_client(end_user_id)
structured_response = await llm_client.response_structured(
messages=messages,
response_model=ExtractedPersona

View File

@@ -85,6 +85,7 @@ class StorageFactory:
access_key_id=settings.S3_ACCESS_KEY_ID,
secret_access_key=settings.S3_SECRET_ACCESS_KEY,
bucket_name=settings.S3_BUCKET_NAME,
endpoint_url=settings.S3_ENDPOINT_URL,
)
else:

View File

@@ -35,6 +35,19 @@ class S3Storage(StorageBackend):
bucket_name: The name of the S3 bucket.
region: The AWS region.
"""
AMAZON_S3_ENDPOINT_MAP = {
"us-east-1": "https://s3.us-east-1.amazonaws.com", # 特殊:无地域后缀
"us-east-2": "https://s3.us-east-2.amazonaws.com",
"us-west-1": "https://s3.us-west-1.amazonaws.com",
"us-west-2": "https://s3.us-west-2.amazonaws.com",
"ap-east-1": "https://s3.ap-east-1.amazonaws.com", # 香港
"ap-southeast-1": "https://s3.ap-southeast-1.amazonaws.com", # 新加坡
"ap-southeast-2": "https://s3.ap-southeast-2.amazonaws.com", # 悉尼
"ap-northeast-1": "https://s3.ap-northeast-1.amazonaws.com", # 东京
"eu-central-1": "https://s3.eu-central-1.amazonaws.com", # 法兰克福
"eu-west-1": "https://s3.eu-west-1.amazonaws.com", # 爱尔兰
# 可根据需要扩展其他地域
}
def __init__(
self,
@@ -42,6 +55,7 @@ class S3Storage(StorageBackend):
access_key_id: str,
secret_access_key: str,
bucket_name: str,
endpoint_url: Optional[str] = None
):
"""
Initialize the S3Storage backend.
@@ -51,6 +65,7 @@ class S3Storage(StorageBackend):
access_key_id: The AWS access key ID.
secret_access_key: The AWS secret access key.
bucket_name: The name of the S3 bucket.
endpoint_url: The complete URL to use for the constructed client.
Raises:
StorageConfigError: If any required configuration is missing.
@@ -69,10 +84,19 @@ class S3Storage(StorageBackend):
self.region = region
self.bucket_name = bucket_name
if not endpoint_url:
# 优先匹配内置映射表(解决特殊地域)
if region in self.AMAZON_S3_ENDPOINT_MAP:
endpoint_url = self.AMAZON_S3_ENDPOINT_MAP[region]
# 兜底:通用拼接(适配未配置的新地域)
else:
endpoint_url = f"https://s3.{region}.amazonaws.com"
try:
self.client = boto3.client(
"s3",
region_name=region,
endpoint_url=endpoint_url,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
)

View File

@@ -53,6 +53,7 @@ class SimpleMCPClient:
else:
await self._connect_http()
except Exception as e:
await self.disconnect()
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
raise MCPConnectionError(f"连接失败: {e}")

View File

@@ -0,0 +1,8 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/24 15:54
from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
__all__ = ["DifyAdapter", "MemoryBearAdapter"]

View File

@@ -0,0 +1,90 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/24 15:58
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from app.core.workflow.adapters.errors import ExceptionDefineition
from app.schemas.workflow_schema import (
EdgeDefinition,
NodeDefinition,
VariableDefinition,
ExecutionConfig,
TriggerConfig
)
class PlatformType(StrEnum):
MEMORY_BEAR = "memory_bear"
DIFY = "dify"
COZE = "coze"
class PlatformMetadata(BaseModel):
platform_name: str
version: str
support_node_types: list[str]
class WorkflowParserResult(BaseModel):
success: bool
platform: PlatformMetadata
execution_config: ExecutionConfig
origin_config: dict[str, Any]
trigger: TriggerConfig | None
edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list)
class WorkflowImportResult(BaseModel):
success: bool
temp_id: str | None = Field(..., description="cache id")
workflow_id: str | None = Field(..., description="workflow id")
edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list)
class BasePlatformAdapter(ABC):
def __init__(self, config: dict[str, Any]):
self.config = config
self.nodes: list[NodeDefinition] = []
self.edges: list[EdgeDefinition] = []
self.conv_variables: list[VariableDefinition] = []
self.errors = []
self.warnings = []
self.branch_node_cache = defaultdict(list)
self.error_branch_node_cache = []
self.node_output_map = {}
@abstractmethod
def get_metadata(self) -> PlatformMetadata:
"""get platform metadata"""
pass
@abstractmethod
def validate_config(self) -> bool:
"""platform configuration validate"""
pass
@abstractmethod
def parse_workflow(self) -> WorkflowParserResult:
"""parse platform configuration to local config"""
pass
@abstractmethod
def map_node_type(self, platform_node_type: str) -> str:
pass

View File

@@ -0,0 +1,75 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/26 14:32
from abc import ABC, abstractmethod
from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType
class BaseConverter(ABC):
@staticmethod
def _convert_string(var):
try:
return str(var)
except:
return DEFAULT_VALUE(VariableType.STRING)
@staticmethod
def _convert_boolean(var):
try:
return bool(var)
except:
return DEFAULT_VALUE(VariableType.BOOLEAN)
@staticmethod
def _convert_number(var):
try:
return float(var)
except:
return DEFAULT_VALUE(VariableType.NUMBER)
@staticmethod
def _convert_object(var):
try:
return dict(var)
except:
return DEFAULT_VALUE(VariableType.OBJECT)
@staticmethod
@abstractmethod
def _convert_file(var):
pass
@staticmethod
def _convert_array_string(var):
try:
return list(var)
except:
return DEFAULT_VALUE(VariableType.ARRAY_STRING)
@staticmethod
def _convert_array_number(var):
try:
return list(var)
except:
return DEFAULT_VALUE(VariableType.ARRAY_NUMBER)
@staticmethod
def _convert_array_boolean(var):
try:
return list(var)
except:
return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN)
@staticmethod
def _convert_array_object(var):
try:
return list(var)
except:
return DEFAULT_VALUE(VariableType.ARRAY_OBJECT)
@staticmethod
@abstractmethod
def _convert_array_file(var):
pass

View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 18:20

View File

@@ -0,0 +1,773 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 18:21
import base64
import re
from typing import Any
from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import (
UnsupportVariableType,
UnknowModelWarning,
ExceptionDefineition,
ExceptionType
)
from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import (
StartNodeConfig,
LLMNodeConfig,
AssignerNodeConfig,
CodeNodeConfig,
LoopNodeConfig,
IterationNodeConfig,
EndNodeConfig,
HttpRequestNodeConfig,
IfElseNodeConfig,
JinjaRenderNodeConfig,
KnowledgeRetrievalNodeConfig,
NoteNodeConfig,
ParameterExtractorNodeConfig,
QuestionClassifierNodeConfig,
VariableAggregatorNodeConfig
)
from app.core.workflow.nodes.cycle_graph.config import (
ConditionDetail as LoopConditionDetail,
ConditionsConfig,
CycleVariable
)
from app.core.workflow.nodes.enums import (
ValueInputType,
ComparisonOperator,
AssignmentOperator,
HttpAuthType,
HttpContentType,
HttpErrorHandle,
NodeType
)
from app.core.workflow.nodes.http_request.config import (
HttpAuthConfig,
HttpContentTypeConfig,
HttpFormData,
HttpTimeOutConfig,
HttpRetryConfig,
HttpErrorDefaultTamplete,
HttpErrorHandleConfig
)
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
class DifyConverter(BaseConverter):
errors: list
warnings: list
branch_node_cache: dict
error_branch_node_cache: list
node_output_map: dict
def __init__(self):
self.CONFIG_CONVERT_MAP = {
NodeType.START: self.convert_start_node_config,
NodeType.LLM: self.convert_llm_node_config,
NodeType.END: self.convert_end_node_config,
NodeType.IF_ELSE: self.convert_if_else_node_config,
NodeType.LOOP: self.convert_loop_node_config,
NodeType.ITERATION: self.convert_iteration_node_config,
NodeType.ASSIGNER: self.convert_assigner_node_config,
NodeType.CODE: self.convert_code_node_config,
NodeType.HTTP_REQUEST: self.convert_http_node_config,
NodeType.JINJARENDER: self.convert_jinja_render_node_config,
NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
NodeType.TOOL: self.convert_tool_node_config,
NodeType.NOTES: self.convert_notes_config,
NodeType.CYCLE_START: lambda x: {},
NodeType.BREAK: lambda x: {},
}
def get_node_convert(self, node_type):
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
return func
def config_validate(
self,
node_id: str,
node_name: str,
config: type[BaseNodeConfig],
value: dict
):
try:
return config.model_validate(value)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
detail=str(e)
))
return None
@staticmethod
def is_variable(expression) -> bool:
return bool(re.match(r"\{\{#(.*?)#}}", expression))
def process_var_selector(self, var_selector):
if not var_selector:
return ""
selector = var_selector.split('.')
if len(selector) not in [2, 3] and var_selector != "context":
raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3:
selector = selector[1:]
if selector[0] == "conversation":
selector[0] = "conv"
var_selector = ".".join(selector)
mapping = {
"sys.query": "sys.message"
} | self.node_output_map
var_selector = mapping.get(var_selector, var_selector)
return var_selector
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
if not self.process_var_selector(".".join(variable_selector)):
return None
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
def trans_variable_format(self, content):
pattern = re.compile(r"\{\{#(.*?)#}}")
def replacer(match: re.Match) -> str:
raw_name = match.group(1)
new_name = self.process_var_selector(raw_name)
return f"{{{{{new_name}}}}}"
return pattern.sub(replacer, content)
@staticmethod
def _convert_file(var):
return None
@staticmethod
def _convert_array_file(var):
return []
@staticmethod
def variable_type_map(source_type) -> VariableType | None:
type_map = {
"file": VariableType.FILE,
"paragraph": VariableType.STRING,
"text-input": VariableType.STRING,
"number": VariableType.NUMBER,
"checkbox": VariableType.BOOLEAN,
"file-list": VariableType.ARRAY_FILE,
"select": VariableType.STRING,
"integer": VariableType.NUMBER,
"float": VariableType.NUMBER,
}
var_type = type_map.get(source_type, source_type)
return var_type
def convert_variable_type(self, target_type: VariableType, origin_value: Any):
if not origin_value:
return DEFAULT_VALUE(target_type)
try:
match target_type:
case VariableType.STRING:
return self._convert_string(origin_value)
case VariableType.NUMBER:
return self._convert_number(origin_value)
case VariableType.BOOLEAN:
return self._convert_boolean(origin_value)
case VariableType.FILE:
return self._convert_file(origin_value)
case VariableType.ARRAY_FILE:
return self._convert_array_file(origin_value)
case _:
return origin_value
except:
raise Exception(f"convert variable failed: {target_type}")
@staticmethod
def convert_compare_operator(operator):
operator_map = {
"is": ComparisonOperator.EQ,
"is not": ComparisonOperator.NE,
"=": ComparisonOperator.EQ,
"": ComparisonOperator.NE,
">": ComparisonOperator.GT,
"<": ComparisonOperator.LT,
"": ComparisonOperator.GE,
"": ComparisonOperator.LE,
"not empty": ComparisonOperator.NOT_EMPTY,
"start with": ComparisonOperator.START_WITH,
"end with": ComparisonOperator.END_WITH,
"not contains": ComparisonOperator.NOT_CONTAINS,
"exists": ComparisonOperator.NOT_EMPTY,
"not exists": ComparisonOperator.EMPTY
}
return operator_map.get(operator, operator)
@staticmethod
def convert_assignment_operator(operator):
operator_map = {
"+=": AssignmentOperator.ADD,
"-=": AssignmentOperator.SUBTRACT,
"*=": AssignmentOperator.MULTIPLY,
"/=": AssignmentOperator.DIVIDE,
"over-write": AssignmentOperator.COVER,
"remove-last": AssignmentOperator.REMOVE_LAST,
"remove-first": AssignmentOperator.REMOVE_FIRST,
"set": AssignmentOperator.ASSIGN,
}
return operator_map.get(operator, operator)
@staticmethod
def convert_http_auth_type(auth_type):
auth_type_map = {
"no-auth": HttpAuthType.NONE,
"bearer": HttpAuthType.BEARER,
"basic": HttpAuthType.BASIC,
"custom": HttpAuthType.CUSTOM,
}
return auth_type_map.get(auth_type, auth_type)
@staticmethod
def convert_http_content_type(content_type):
content_type_map = {
"none": HttpContentType.NONE,
"form-data": HttpContentType.FROM_DATA,
"x-www-form-urlencoded": HttpContentType.WWW_FORM,
"json": HttpContentType.JSON,
"raw-text": HttpContentType.RAW,
"binary": HttpContentType.BINARY,
}
return content_type_map.get(content_type, content_type)
@staticmethod
def convert_http_error_handle_type(handle_type):
handle_type_map = {
"none": HttpErrorHandle.NONE,
"fail-branch": HttpErrorHandle.BRANCH,
"default-value": HttpErrorHandle.DEFAULT,
}
return handle_type_map.get(handle_type, handle_type)
def convert_start_node_config(self, node: dict) -> dict:
node_data = node["data"]
start_vars = []
for var in node_data["variables"]:
var_type = self.variable_type_map(var["type"])
if not var_type:
self.errors.append(
UnsupportVariableType(
scope=node["id"],
name=var["variable"],
var_type=var["type"],
node_id=node["id"],
node_name=node_data["title"]
)
)
continue
if var_type in ["file", "array[file]"]:
self.errors.append(
ExceptionDefineition(
type=ExceptionType.VARIABLE,
node_id=node["id"],
node_name=node_data["title"],
name=var["variable"],
detail=f"Unsupported Variable type for start node: {var_type}"
)
)
continue
var_def = VariableDefinition(
name=var["variable"],
type=var_type,
required=var["required"],
default=self.convert_variable_type(
var_type, var.get("default")
),
description=var["label"],
max_length=var.get("max_length", 50),
)
start_vars.append(var_def)
result = StartNodeConfig.model_construct(
variables=start_vars
).model_dump()
self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result)
return result
def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
UnknowModelWarning(
node_id=node["id"],
node_name=node_data["title"],
model_name=node_data["model"].get("name")
)
)
categories = []
for category in node_data["classes"]:
self.branch_node_cache[node["id"]].append(category["id"])
categories.append(
ClassifierConfig.model_construct(
class_name=category["name"],
)
)
result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result)
return result
def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
UnknowModelWarning(
node_id=node["id"],
node_name=node_data["title"],
model_name=node_data["model"].get("name")
)
)
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
memory = MemoryWindowSetting(
enable=bool(node_data.get("memory")),
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20))
)
messages = []
for message in node_data["prompt_template"]:
messages.append(
MessageConfig(
role=message["role"],
content=self.trans_variable_format(message["text"])
)
)
if memory.enable:
messages.append(
MessageConfig(
role="user",
content=self.trans_variable_format(
node_data["memory"].get("query_prompt_template") or "{{#sys.query#}}"
)
)
)
vision = node_data["vision"]["enabled"]
vision_input = self._process_list_variable_litearl(
node_data["vision"]["configs"]["variable_selector"]
) if vision else None
result = LLMNodeConfig.model_construct(
model_id=None,
context=context,
memory=memory,
vision=vision,
vision_input=vision_input,
messages=messages
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result)
return result
def convert_end_node_config(self, node: dict) -> dict:
node_data = node["data"]
result = EndNodeConfig.model_construct(
output=self.trans_variable_format(node_data.get("answer", "")),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
return result
def convert_if_else_node_config(self, node: dict) -> dict:
node_data = node["data"]
cases = []
for case in node_data["cases"]:
case_id = case.get("id") or case.get("case_id")
logical_operator = case["logical_operator"]
conditions = []
for condition in case["conditions"]:
right_value = condition["value"]
condition_detail = ConditionDetail(
operator=self.convert_compare_operator(condition["comparison_operator"]),
left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}",
right=self.trans_variable_format(
right_value
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
self.variable_type_map(condition["varType"]),
condition["value"]
),
input_type=ValueInputType.VARIABLE
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
)
conditions.append(condition_detail)
cases.append(
ConditionBranchConfig(
logical_operator=logical_operator,
expressions=conditions
)
)
self.branch_node_cache[node["id"]].append(case_id)
result = IfElseNodeConfig.model_construct(
cases=cases
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result)
return result
def convert_loop_node_config(self, node: dict) -> dict:
node_data = node["data"]
logical_operator = node_data["logical_operator"]
conditions = []
for condition in node_data["break_conditions"]:
right_value = condition["value"]
conditions.append(
LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]),
right=self.trans_variable_format(
right_value
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
self.variable_type_map(condition["varType"]),
condition["value"]
),
input_type=ValueInputType.VARIABLE
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
)
)
condition_config = ConditionsConfig.model_construct(
logical_operator=logical_operator,
expressions=conditions
)
loop_variables = []
for variable in node_data["loop_variables"]:
right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable.get("value", ""))
else:
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append(
CycleVariable(
name=variable["label"],
type=right_value_type,
value=right_value,
input_type=right_input_type
)
)
result = LoopNodeConfig.model_construct(
condition=condition_config,
cycle_vars=loop_variables,
max_loop=node_data.get("loop_count", 10)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result)
return result
def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"]
result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"],
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
return result
def convert_assigner_node_config(self, node: dict) -> dict:
node_data = node["data"]
assignments = []
for assignment in node_data["items"]:
if assignment.get("operation") is None or assignment.get("value") is None:
continue
assignments.append(
AssignmentItem(
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
value=self._process_list_variable_litearl(
assignment["value"]
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
operation=self.convert_assignment_operator(assignment["operation"])
)
)
result = AssignerNodeConfig.model_construct(
assignments=assignments
).model_dump()
self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result)
return result
def convert_code_node_config(self, node: dict) -> dict:
node_data = node["data"]
input_variables = []
for input_variable in node_data["variables"]:
input_variables.append(
InputVariable.model_construct(
name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
)
)
output_variables = []
for output_variable in node_data["outputs"]:
output_variables.append(
OutputVariable.model_construct(
name=output_variable,
type=node_data["outputs"][output_variable]["type"],
)
)
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
result = CodeNodeConfig.model_construct(
input_variables=input_variables,
language=node_data["code_language"],
output_variables=output_variables,
code=code
).model_dump()
self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result)
return result
def convert_http_node_config(self, node: dict) -> dict:
node_data = node["data"]
if node_data["authorization"]["type"] != 'no-auth':
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
auth_config = HttpAuthConfig.model_construct(
auth_type=auth_type,
header=node_data["authorization"]["config"].get("header"),
api_key=node_data["authorization"]["config"].get("api_key"),
)
else:
auth_config = HttpAuthConfig()
content_type = self.convert_http_content_type(node_data["body"]["type"])
if content_type == HttpContentType.FROM_DATA:
body_content = []
for content in node_data["body"]["data"]:
body_content.append(
HttpFormData(
key=self.trans_variable_format(content["key"]),
type=content["type"],
value=self.trans_variable_format(content["value"]),
)
)
elif content_type == HttpContentType.WWW_FORM:
body_content = {}
for content in node_data["body"]["data"]:
body_content[
self.trans_variable_format(content["key"])
] = self.trans_variable_format(content["value"])
else:
if node_data["body"]["data"]:
body_content = (node_data["body"]["data"][0].get("value") or
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
else:
body_content = ""
headers = {}
for header in node_data.get("headers", "").split("\n"):
if not header:
continue
key_value = header.split(":")
if len(key_value) == 2:
headers[
self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1])
else:
self.warnings.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node["id"],
node_name=node_data["title"],
detail=f"Invalid header/param - {header}",
))
params = {}
for param in node_data.get("params", "").split("\n"):
if not param:
continue
key_value = param.split(":")
if len(key_value) == 2:
params[
self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1])
else:
self.warnings.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node["id"],
node_name=node_data["title"],
detail=f"Invalid header/param - {param}",
))
error_handle_type = self.convert_http_error_handle_type(
node_data.get("error_strategy", "none")
)
default_value = None
if error_handle_type == HttpErrorHandle.DEFAULT:
default_body = ""
default_header = {}
default_status_code = 0
for var in node_data.get("default_value") or []:
if var["key"] == "body":
default_body = var["value"]
elif var["key"] == "header":
default_header = var["value"]
elif var["key"] == "status_code":
default_status_code = var["value"]
default_value = HttpErrorDefaultTamplete(
body=default_body,
headers=default_header,
status_code=default_status_code,
)
self.error_branch_node_cache.append(node['id'])
result = HttpRequestNodeConfig.model_construct(
method=node_data["method"].upper(),
url=node_data["url"],
auth=auth_config,
body=HttpContentTypeConfig.model_construct(
content_type=self.convert_http_content_type(node_data["body"]["type"]),
data=body_content,
),
headers=headers,
params=params,
verify_ssl=node_data.get("ssl_verify", False),
timeouts=HttpTimeOutConfig.model_construct(
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
),
retry=HttpRetryConfig.model_construct(
enable=node_data["retry_config"]["retry_enabled"],
max_attempts=node_data["retry_config"]["max_retries"],
retry_interval=node_data["retry_config"]["retry_interval"],
),
error_handle=HttpErrorHandleConfig.model_construct(
method=error_handle_type,
default=default_value,
)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result)
return result
def convert_jinja_render_node_config(self, node: dict) -> dict:
node_data = node["data"]
mapping = []
for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"])
))
result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"],
mapping=mapping,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result)
return result
def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(ExceptionDefineition(
node_id=node["id"],
node_name=node_data["title"],
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.",
))
result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
return result
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
UnknowModelWarning(
node_id=node["id"],
node_name=node_data["title"],
model_name=node_data["model"].get("name")
)
)
params = []
for param in node_data.get("parameters", []):
params.append(
ParamsConfig.model_construct(
name=param["name"],
desc=param["description"],
required=param["required"],
type=param["type"],
)
)
result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]),
params=params,
prompt=node_data.get("instruction")
).model_dump()
self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result)
return result
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
node_data = node["data"]
advanced_settings = node_data.get("advanced_settings", {})
group_variables = {}
group_type = {}
if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables = [
self._process_list_variable_litearl(variable)
for variable in node_data["variables"]
]
group_type["output"] = node_data["output_type"]
else:
for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable)
for variable in group["variables"]
]
group_type[group["group_name"]] = group["output_type"]
result = VariableAggregatorNodeConfig.model_construct(
group=advanced_settings.get("group_enabled", False),
group_variables=group_variables,
group_type=group_type,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result)
return result
def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(ExceptionDefineition(
node_id=node["id"],
node_name=node_data["title"],
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the tool node.",
))
return {}
@staticmethod
def convert_notes_config(node: dict):
node_data = node["data"]
result = NoteNodeConfig.model_construct(
author=node_data.get("author", ""),
text=node_data.get("text", ""),
width=node_data.get("width", 80),
height=node_data.get("height", 80),
theme=node_data.get("theme", "blue"),
show_author=node_data.get("showAuthor", True)
).model_dump()
return result

View File

@@ -0,0 +1,259 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/24 16:05
from typing import Any
from app.core.logging_config import get_logger
from app.core.workflow.adapters.base_adapter import (
BasePlatformAdapter,
PlatformMetadata,
PlatformType,
WorkflowParserResult
)
from app.core.workflow.adapters.dify.converter import DifyConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import (
NodeDefinition,
EdgeDefinition,
VariableDefinition,
TriggerConfig,
ExecutionConfig
)
logger = get_logger()
class DifyAdapter(BasePlatformAdapter, DifyConverter):
NODE_TYPE_MAPPING = {
"start": NodeType.START,
"llm": NodeType.LLM,
"answer": NodeType.END,
"if-else": NodeType.IF_ELSE,
"loop-start": NodeType.CYCLE_START,
"iteration-start": NodeType.CYCLE_START,
"assigner": NodeType.ASSIGNER,
"loop": NodeType.LOOP,
"iteration": NodeType.ITERATION,
"loop-end": NodeType.BREAK,
"code": NodeType.CODE,
"http-request": NodeType.HTTP_REQUEST,
"template-transform": NodeType.JINJARENDER,
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL,
"": NodeType.NOTES
}
def __init__(self, config: dict[str, Any]):
DifyConverter.__init__(self)
BasePlatformAdapter.__init__(self, config)
def get_metadata(self) -> PlatformMetadata:
return PlatformMetadata(
platform_name=PlatformType.DIFY,
version="0.5.0",
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
)
def map_node_type(self, platform_node_type) -> NodeType:
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@property
def origin_nodes(self):
return self.config.get("workflow").get("graph").get("nodes")
@property
def origin_edges(self):
return self.config.get("workflow").get("graph").get("edges")
@staticmethod
def _valid_nodes(node: dict[str, Any]):
if "data" not in node:
return False
if "type" not in node["data"]:
return False
if "id" not in node or "type" not in node:
return False
return True
def validate_config(self) -> bool:
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
if not all(field in self.config for field in require_fields):
return False
if self.config.get("app", {}).get("mode") == "workflow":
self.errors.append(ExceptionDefineition(
type=ExceptionType.PLATFORM,
detail="workflow mode is not supported"
))
return False
for node in self.origin_nodes:
if not self._valid_nodes(node):
return False
return True
def parse_workflow(self) -> WorkflowParserResult:
self._init_node_output_map()
for node in self.origin_nodes:
node = self._convert_node(node)
if node:
self.nodes.append(node)
nodes_id = [node.id for node in self.nodes]
for edge in self.origin_edges:
source = edge["source"]
target = edge["target"]
if source not in nodes_id or target not in nodes_id:
continue
edge = self._convert_edge(edge)
if edge:
self.edges.append(edge)
#
for variable in self.config.get("workflow").get("conversation_variables"):
con_var = self._convert_variable(variable)
if variable:
self.conv_variables.append(con_var)
#
# for variables in config.get("workflow").get("environment_variables"):
# variable = self._convert_variable(variables)
# conv_variables.append(variable)
trigger = self._convert_trigger({})
execution_config = self._convert_execution({})
return WorkflowParserResult(
success=not self.errors and not self.warnings,
platform=self.get_metadata(),
execution_config=execution_config,
origin_config=self.config,
trigger=trigger,
edges=self.edges,
nodes=self.nodes,
variables=self.conv_variables,
warnings=self.warnings,
errors=self.errors
)
def _init_node_output_map(self):
for node in self.origin_nodes:
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
elif self.map_node_type(node["data"]["type"]) == NodeType.KNOWLEDGE_RETRIEVAL:
self.node_output_map[f"{node['id']}.result"] = f"{node['id']}.output"
def _convert_cycle_node_position(self, node_id: str, position: dict):
for node in self.origin_nodes:
if node["id"] == node_id:
return {
"x": node["position"]["x"] + position["x"],
"y": node["position"]["y"] + position["y"]
}
self.errors.append(
ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node_id,
detail="parent cycle node not found"
)
)
raise Exception("parent cycle node not found")
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
node_data = node["data"]
try:
node_type = self.map_node_type(node_data["type"])
return NodeDefinition(
id=node["id"],
type=node_type,
name=node_data.get("title") or "notes",
cycle=node.get("parentId"),
description=None,
config=self._convert_node_config(node_type, node),
position={
"x": node["position"]["x"],
"y": node["position"]["y"]
} if node.get("parentId") is None else self._convert_cycle_node_position(
node["parentId"],
node["position"]
),
error_handling=None,
cache=None
)
except Exception as e:
logger.debug(f"convert node error - {e}", exc_info=True)
def _convert_node_config(self, node_type: NodeType, node: dict):
try:
node_data = node["data"]
converter = self.get_node_convert(node_type)
if node_type == NodeType.UNKNOWN:
self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
detail=f"node type {node_data.get('type')} is unsupported",
))
return converter(node)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
detail=f"convert node error - {e}",
))
raise e
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
try:
source = edge["source"]
target = edge["target"]
label = None
if source in self.branch_node_cache:
case_id = edge["sourceHandle"]
if case_id == "false":
label = f'CASE{len(self.branch_node_cache[source]) + 1}'
else:
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
if source in self.error_branch_node_cache:
case_id = edge["sourceHandle"]
if case_id == "source":
label = "SUCCESS"
else:
label = "ERROR"
return EdgeDefinition(
id=edge["id"],
source=source,
target=target,
label=label,
)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.EDGE,
detail=f"convert edge error - {e}",
))
logger.debug(f"convert edge error - {e}", exc_info=True)
return None
def _convert_variable(self, variable) -> VariableDefinition | None:
try:
return VariableDefinition(
name=variable["name"],
default=variable["value"],
type=self.variable_type_map(variable["value_type"]),
description=variable.get("description")
)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.VARIABLE,
name=variable.get("name"),
detail=f"convert variable error - {e}",
))
def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None:
pass
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
return ExecutionConfig()

View File

@@ -0,0 +1,75 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/26 11:29
from enum import StrEnum
from pydantic import BaseModel
class ExceptionType(StrEnum):
NODE = "node"
EDGE = "edge"
VARIABLE = "variable"
TRIGGER = "trigger"
EXECUTION = "execution"
CONFIG = "config"
PLATFORM = "platform"
UNKNOWN = "unknown"
class ExceptionDefineition(BaseModel):
type: ExceptionType
detail: str
node_id: str | None = None
node_name: str | None = None
scope: str | None = None
name: str | None = None
class UnknowModelWarning(ExceptionDefineition):
type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id, node_name, model_name):
super().__init__(
detail=f"Please specify the model mapping manually for model: {model_name}",
node_id=node_id,
node_name=node_name
)
class UnknowError(ExceptionDefineition):
type: ExceptionType = ExceptionType.UNKNOWN
def __init__(self, detail: str, **kwargs):
super().__init__(detail=detail, **kwargs)
class UnsupportPlatform(ExceptionDefineition):
type: ExceptionType = ExceptionType.PLATFORM
def __init__(self, platform: str):
super().__init__(detail=f"Unsupport platform {platform}")
class UnsupportVariableType(ExceptionDefineition):
type: ExceptionType = ExceptionType.VARIABLE
def __init__(self, scope, name, var_type: str, **kwargs):
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type[{var_type}]", **kwargs)
class InvalidConfiguration(ExceptionDefineition):
type: ExceptionType = ExceptionType.CONFIG
def __init__(self):
super().__init__(detail="Invalid workflow configuration format")
class UnsupportNodeType(ExceptionDefineition):
type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id: str, node_type: str):
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")

View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/26 11:30

View File

@@ -0,0 +1,155 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 14:11
from typing import Any
from app.core.logging_config import get_logger
from app.core.workflow.adapters.base_adapter import (
PlatformMetadata,
PlatformType,
BasePlatformAdapter,
WorkflowParserResult
)
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
logger = get_logger()
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
def __init__(self, config: dict[str, Any]):
MemoryBearConverter.__init__(self)
BasePlatformAdapter.__init__(self, config)
@property
def origin_nodes(self):
return self.config.get("workflow").get("nodes") or []
@property
def origin_edges(self):
return self.config.get("workflow").get("edges") or []
@property
def origin_variables(self):
return self.config.get("workflow").get("variables") or []
def get_metadata(self) -> PlatformMetadata:
return PlatformMetadata(
platform_name=PlatformType.MEMORY_BEAR,
version="0.2.5",
support_node_types=list(VALID_NODE_TYPES)
)
def map_node_type(self, platform_node_type: str) -> NodeType:
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@staticmethod
def _valid_node(node: dict[str, Any]) -> bool:
if "id" not in node or "type" not in node:
return False
if not isinstance(node.get("config"), dict):
return False
return True
def validate_config(self) -> bool:
require_fields = frozenset({'app', 'workflow'})
if not all(field in self.config for field in require_fields):
return False
for node in self.origin_nodes:
if not self._valid_node(node):
return False
return True
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
node_id = node.get("id")
node_name = node.get("name")
try:
node_type = self.map_node_type(node["type"])
if node_type == NodeType.UNKNOWN:
self.errors.append(UnsupportNodeType(
node_id=node_id,
node_type=node["type"]
))
return None
config = node.get("config") or {}
converter = self.get_node_convert(node_type)
converter(node_id, node_name, config) # validates and appends errors if invalid
return NodeDefinition(**node)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node_id,
node_name=node_name,
detail=f"convert node error - {e}"
))
logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
return None
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
try:
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
self.warnings.append(ExceptionDefineition(
type=ExceptionType.EDGE,
detail=f"edge {edge.get('id')} skipped: source or target node not found"
))
return None
return EdgeDefinition(**edge)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.EDGE,
detail=f"convert edge error - {e}"
))
logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
return None
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
try:
return VariableDefinition(**variable)
except Exception as e:
self.warnings.append(ExceptionDefineition(
type=ExceptionType.VARIABLE,
name=variable.get("name"),
detail=f"convert variable error - {e}"
))
logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
return None
def parse_workflow(self) -> WorkflowParserResult:
for node in self.origin_nodes:
converted = self._convert_node(node)
if converted:
self.nodes.append(converted)
valid_node_ids = {n.id for n in self.nodes}
for edge in self.origin_edges:
converted = self._convert_edge(edge, valid_node_ids)
if converted:
self.edges.append(converted)
for variable in self.origin_variables:
converted = self._convert_variable(variable)
if converted:
self.conv_variables.append(converted)
return WorkflowParserResult(
success=not self.errors and not self.warnings,
platform=self.get_metadata(),
execution_config=ExecutionConfig(),
origin_config=self.config,
trigger=None,
edges=self.edges,
nodes=self.nodes,
variables=self.conv_variables,
warnings=self.warnings,
errors=self.errors,
)

View File

@@ -0,0 +1,85 @@
# -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import (
StartNodeConfig,
EndNodeConfig,
LLMNodeConfig,
AgentNodeConfig,
IfElseNodeConfig,
KnowledgeRetrievalNodeConfig,
AssignerNodeConfig,
CodeNodeConfig,
HttpRequestNodeConfig,
JinjaRenderNodeConfig,
VariableAggregatorNodeConfig,
ParameterExtractorNodeConfig,
LoopNodeConfig,
IterationNodeConfig,
QuestionClassifierNodeConfig,
ToolNodeConfig,
MemoryReadNodeConfig,
MemoryWriteNodeConfig,
NoteNodeConfig,
)
from app.core.workflow.nodes.enums import NodeType
class MemoryBearConverter(BaseConverter):
errors: list
warnings: list
CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
NodeType.START: StartNodeConfig,
NodeType.END: EndNodeConfig,
NodeType.ANSWER: EndNodeConfig,
NodeType.LLM: LLMNodeConfig,
NodeType.AGENT: AgentNodeConfig,
NodeType.IF_ELSE: IfElseNodeConfig,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
NodeType.ASSIGNER: AssignerNodeConfig,
NodeType.CODE: CodeNodeConfig,
NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
NodeType.JINJARENDER: JinjaRenderNodeConfig,
NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
NodeType.LOOP: LoopNodeConfig,
NodeType.ITERATION: IterationNodeConfig,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
NodeType.TOOL: ToolNodeConfig,
NodeType.MEMORY_READ: MemoryReadNodeConfig,
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
NodeType.NOTES: NoteNodeConfig,
}
@staticmethod
def _convert_file(var):
return None
@staticmethod
def _convert_array_file(var):
return []
def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
try:
return config_cls.model_validate(value)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
detail=str(e)
))
return None
def get_node_convert(self, node_type: NodeType):
config_cls = self.CONFIG_CLASS_MAP.get(node_type)
if not config_cls:
return lambda node_id, node_name, config: config
def validate(node_id: str, node_name: str, config: dict):
self.config_validate(node_id, node_name, config_cls, config)
return config
return validate

View File

@@ -0,0 +1,34 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 14:19
from typing import Any
from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter
from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType
class PlatformAdapterRegistry:
_adapters: dict[str, type[BasePlatformAdapter]] = {}
@classmethod
def register(cls, platform: str, adapter: type[BasePlatformAdapter]):
cls._adapters[platform] = adapter
@classmethod
def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter:
if platform not in cls._adapters:
raise ValueError(f"Unsupported platform: {platform}")
return cls._adapters.get(platform)(config)
@classmethod
def list_platforms(cls) -> list[str]:
return list(cls._adapters.keys())
@classmethod
def is_supported(cls, platform: str) -> bool:
return platform in cls._adapters
PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter)
PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter)

View File

@@ -127,7 +127,7 @@ class EventStreamHandler:
yield {
"event": "message",
"data": {
"chunk": data.get("chunk")
"content": data.get("chunk")
}
}

View File

@@ -292,6 +292,8 @@ class GraphBuilder:
"""
for node in self.nodes:
node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
@@ -320,7 +322,7 @@ class GraphBuilder:
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
if node_instance:
# Wrap node's run method to avoid closure issues

View File

@@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
logger = get_logger(__name__)
SCOPE_PATTERN = re.compile(
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
)
@@ -274,7 +274,7 @@ class StreamOutputCoordinator:
yield {
"event": "message",
"data": {
"chunk": final_chunk
"content": final_chunk
}
}

View File

@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
instance:
The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringVariable,
NumberVariable, ArrayObject[StringVariable]).
NumberVariable, ArrayVariable[StringVariable]).
mut:
Whether the variable is mutable.
"""
@@ -152,6 +152,36 @@ class VariablePool:
return None
return var_instance
def get_instance(
self,
selector: str,
default: Any = None,
strict: bool = True
):
"""Retrieve a variable instance from the variable pool.
Args:
selector:
Variable selector as a string variable literal (e.g. "{{ sys.message }}").
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
Returns:
The variable instance object if it exists; otherwise returns `default`.
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
return variable_struct.instance
def get_value(
self,
selector: str,
@@ -273,38 +303,52 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
def get_all_system_vars(self) -> dict[str, Any]:
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
Returns:
系统变量字典
"""
sys_namespace = self.variables.get("sys", {})
if literal:
return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
def get_all_conversation_vars(self) -> dict[str, Any]:
def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
"""获取所有会话变量
Returns:
会话变量字典
"""
conv_namespace = self.variables.get("conv", {})
if literal:
return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]:
def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
Returns:
节点输出字典,键为节点 ID
"""
runtime_vars = {
namespace: {
k: v.instance.get_value()
for k, v in vars_dict.items()
if literal:
runtime_vars = {
namespace: {
k: v.instance.to_literal()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
else:
runtime_vars = {
namespace: {
k: v.instance.get_value()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
return runtime_vars
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:

View File

@@ -132,24 +132,24 @@ class WorkflowExecutor:
start_time = datetime.datetime.now()
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
# Execute the workflow
try:
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes
@@ -158,24 +158,42 @@ class WorkflowExecutor:
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
# Append messages for user and assistant
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
if input_data.get("files"):
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "user",
"content": input_data.get("files")
},
{
"role": "assistant",
"content": full_content
}
]
)
else:
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
# Calculate elapsed time
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
@@ -231,23 +249,23 @@ class WorkflowExecutor:
}
}
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
try:
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
full_content = ''
self.stream_coordinator.update_scope_activation("sys")
@@ -272,7 +290,7 @@ class WorkflowExecutor:
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
if event_type == "node_chunk":
async for msg_event in self.event_handler.handle_node_chunk_event(data):
full_content += msg_event["data"]["chunk"]
full_content += msg_event["data"]["content"]
yield msg_event
elif event_type == "node_error":
@@ -295,12 +313,12 @@ class WorkflowExecutor:
self.graph,
self.execution_context.checkpoint_config
):
full_content += msg_event["data"]['chunk']
full_content += msg_event["data"]['content']
yield msg_event
# Flush any remaining chunks
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
full_content += msg_event["data"]['chunk']
full_content += msg_event["data"]['content']
yield msg_event
result = graph.get_state(self.execution_context.checkpoint_config).values
@@ -308,21 +326,39 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds()
# Append messages for user and assistant
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
if input_data.get("files"):
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "user",
"content": input_data.get("files")
},
{
"role": "assistant",
"content": full_content
}
]
)
else:
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
logger.info(
f"Workflow execution completed (streaming), "
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
)
yield {

View File

@@ -14,9 +14,9 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db
from app.db import get_db_context
from app.models import AppRelease
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
"""准备 Agent公共逻辑
Args:
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db())
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
with get_db_context() as db:
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db)
return draft_service, release, message
return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(variable_pool)
release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
)
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
)
response = result.get("response", "")
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(variable_pool)
release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应
full_response = ""
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent流式
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")

View File

@@ -88,6 +88,8 @@ class AssignerNode(BaseNode):
await operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
await operator.remove_last()
case AssignmentOperator.EXTEND:
await operator.extend()
case _:
raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -85,20 +85,20 @@ class BaseNodeConfig(BaseModel):
- tags: 节点标签(用于分类和搜索)
"""
name: str | None = Field(
default=None,
description="节点名称(显示名称),如果不设置则使用节点 ID"
)
description: str | None = Field(
default=None,
description="节点描述,说明节点的作用"
)
tags: list[str] = Field(
default_factory=list,
description="节点标签,用于分类和搜索"
)
# name: str | None = Field(
# default=None,
# description="节点名称(显示名称),如果不设置则使用节点 ID"
# )
#
# description: str | None = Field(
# default=None,
# description="节点描述,说明节点的作用"
# )
#
# tags: list[str] = Field(
# default_factory=list,
# description="节点标签,用于分类和搜索"
# )
class Config:
"""Pydantic 配置"""

View File

@@ -1,6 +1,8 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from functools import cached_property
from typing import Any, AsyncGenerator
@@ -10,8 +12,11 @@ from app.core.config import settings
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
from app.services.multimodal_service import PROVIDER_STRATEGIES
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -196,7 +201,7 @@ class BaseNode(ABC):
timeout=timeout
)
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
# Extract processed outputs using subclass-defined logic.
extracted_output = self._extract_output(business_result)
@@ -219,7 +224,7 @@ class BaseNode(ABC):
} | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(
f"Node {self.node_id} execution timed out ({timeout} seconds)."
)
@@ -230,7 +235,7 @@ class BaseNode(ABC):
variable_pool,
)
except Exception as e:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(
f"Node {self.node_id} execution failed: {e}",
exc_info=True,
@@ -307,10 +312,10 @@ class BaseNode(ABC):
"done": done
})
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.info(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
@@ -337,7 +342,7 @@ class BaseNode(ABC):
yield state_update | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
error_output = self._wrap_error(
f"Node execution timed out ({timeout}s)",
@@ -347,7 +352,7 @@ class BaseNode(ABC):
)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
yield error_output
@@ -548,9 +553,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars(),
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
strict=strict
)
@@ -614,16 +619,48 @@ class BaseNode(ABC):
return variable_pool.has(selector)
@staticmethod
async def process_message(provider, content, enable_file=False) -> dict | str | None:
async def process_message(
provider: str,
is_omni: bool,
content: str | dict | FileObject,
enable_file=False
) -> list | str | None:
if isinstance(content, dict):
content = FileObject(
type=content.get("type"),
url=content.get("url"),
transfer_method=content.get("transfer_method"),
origin_file_type=content.get("origin_file_type"),
file_id=content.get("file_id"),
is_file=True
)
if isinstance(content, str):
if enable_file:
return {"text": content}
return [{"type": "text", "text": content}]
return content
elif isinstance(content, dict):
trans_tool = PROVIDER_STRATEGIES[provider]()
result = await trans_tool.format_image(content["url"])
return result
raise TypeError('Unexpect input value type')
elif isinstance(content, FileObject):
if content.content_cache.get(provider):
return content.content_cache[provider]
with get_db_read() as db:
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
file_obj = FileInput(
type=content.type,
url=content.url,
transfer_method=content.transfer_method,
origin_file_type=content.origin_file_type,
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
)
file_obj.set_content(content.get_content())
message = await multimodel_service.process_files(
[file_obj]
)
content.set_content(file_obj.get_content())
if message:
content.content_cache[provider] = message
return message
return None
raise TypeError(f'Unexpect input value type - {type(content)}')
@staticmethod
def process_model_output(content) -> str:
@@ -639,3 +676,12 @@ class BaseNode(ABC):
elif isinstance(content, str):
return content
return result
@staticmethod
def model_balance(model_config: ModelConfig) -> ModelApiKey:
api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys:
raise ValueError("No active API keys available for model")
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
return api_keys[0]

View File

@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.notes.config import NoteNodeConfig
__all__ = [
# 基础类
@@ -47,5 +48,6 @@ __all__ = [
"ToolNodeConfig",
"MemoryReadNodeConfig",
"MemoryWriteNodeConfig",
"CodeNodeConfig"
"CodeNodeConfig",
"NoteNodeConfig"
]

View File

@@ -91,8 +91,8 @@ class IterationRuntime:
return loopstate
def merge_conv_vars(self):
self.variable_pool.get_all_conversation_vars().update(
self.child_variable_pool.get_all_conversation_vars()
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"]
)
async def run_task(self, item, idx):

View File

@@ -156,7 +156,7 @@ class LoopRuntime:
def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {})
self.child_variable_pool.variables["conv"]
)
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars

View File

@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
if config.flatten:
outputs['output'] = config.output_type
else:
outputs['output'] = VariableType.ARRAY_STRING
outputs['output'] = VariableType.NESTED_ARRAY
else:
outputs['output'] = VariableType(f"array[{config.output_type}]")
return outputs

View File

@@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig):
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="工作流的最终输出"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
# # 输出变量定义
# output_variables: list[VariableDefinition] = Field(
# default_factory=lambda: [
# VariableDefinition(
# name="output",
# type=VariableType.STRING,
# description="工作流的最终输出"
# )
# ],
# description="输出变量定义(自动生成,通常不需要修改)"
# )
class Config:
json_schema_extra = {

View File

@@ -24,6 +24,9 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
UNKNOWN = "unknown"
NOTES = "notes"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
@@ -61,6 +64,7 @@ class AssignmentOperator(StrEnum):
APPEND = "append"
REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first"
EXTEND = "extend"
class HttpRequestMethod(StrEnum):

View File

@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
from app.core.workflow.variable.base_variable import FileObject
class HttpAuthConfig(BaseModel):
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
description="Http response headers"
)
files: list[FileObject] = Field(
default_factory=list,
description="List of files",
)
output: str = Field(
default="SUCCESS",
description="HTTP response body",

View File

@@ -1,22 +1,146 @@
import asyncio
import json
import logging
import mimetypes
import uuid
import imghdr
from email.message import Message
from typing import Any, Callable, Coroutine
import httpx
# import filetypes # TODO: File support (Feature)
from httpx import AsyncClient, Response, Timeout
import magic
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.utils.file_processer import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod
logger = logging.getLogger(__file__)
class HttpResponse:
def __init__(self, response: httpx.Response):
self.response = response
self.headers = dict(response.headers)
self._is_file: bool | None = None
@property
def content_type(self) -> str:
return self.headers.get("content-type", "")
@property
def content_disposition(self) -> Message | None:
content_disposition = self.headers.get("content-disposition", "")
if content_disposition:
msg = Message()
msg["content-disposition"] = content_disposition
return msg
return None
@property
def is_file(self) -> bool:
if self._is_file is not None:
return self._is_file
content_type = self.content_type.split(";")[0].strip().lower()
parsed_content_disposition = self.content_disposition
if parsed_content_disposition:
disp_type = parsed_content_disposition.get_content_disposition()
filename = parsed_content_disposition.get_filename()
if disp_type == "attachment" or filename:
self._is_file = True
return True
if content_type.startswith("text/") and "csv" not in content_type:
return False
if content_type.startswith("application/"):
if any(
text_type in content_type
for text_type in {"json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql"}
):
self._is_file = False
return False
try:
content_sample = self.response.content[:1024]
content_sample.decode("utf-8")
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
if any(marker in content_sample for marker in text_markers):
return False
except UnicodeDecodeError:
self._is_file = True
return True
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
if main_type:
self._is_file = main_type.split("/")[0] in ("application", "image", "audio", "video")
return self._is_file
self._is_file = any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
return self._is_file
@property
def is_image(self):
if self.is_file:
kind = imghdr.what(None, h=self.response.content)
return kind is not None
return False
@property
def url(self) -> str:
return str(self.response.url)
@property
def body(self) -> str:
if self.is_file:
return f"{'!' if self.is_image else ''}[file]({self.url})"
return self.response.text
@staticmethod
def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
mime = magic.from_buffer(file_bytes, mime=True)
if mime.startswith("image"):
return FileType.IMAGE, mime
elif mime.startswith("video"):
return FileType.VIDEO, mime
elif mime.startswith("audio"):
return FileType.AUDIO, mime
elif mime in ["application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"text/plain"]:
return FileType.DOCUMENT, mime
return None, None
@property
def files(self) -> list[FileObject]:
file_type, mime_type = self.get_file_type(self.response.content)
origin_file_type = mime_to_file_type(mime_type)
if self.is_file and file_type and origin_file_type:
file_obj = FileObject(
type=file_type,
url=self.url,
transfer_method=TransferMethod.REMOTE_URL.value,
origin_file_type=origin_file_type,
file_id=None,
is_file=True
)
file_obj.set_content(self.response.content)
return [
file_obj
]
return []
class HttpRequestNode(BaseNode):
"""
HTTP Request Workflow Node.
@@ -42,6 +166,7 @@ class HttpRequestNode(BaseNode):
"body": VariableType.STRING,
"status_code": VariableType.NUMBER,
"headers": VariableType.OBJECT,
"files": VariableType.ARRAY_FILE,
"output": VariableType.STRING
}
@@ -115,7 +240,7 @@ class HttpRequestNode(BaseNode):
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return params
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
"""
Build HTTP request body arguments for httpx request methods.
@@ -135,16 +260,35 @@ class HttpRequestNode(BaseNode):
))
case HttpContentType.FROM_DATA:
data = {}
content["files"] = {}
for item in self.typed_config.body.data:
if item.type == "text":
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
variable_pool)
elif item.type == "file":
# TODO: File support (Feature)
pass
content["files"][self._render_template(item.key, variable_pool)] = (
uuid.uuid4().hex,
await variable_pool.get_instance(item.value).get_content()
)
content["data"] = data
case HttpContentType.BINARY:
# TODO: File support (Feature)
pass
content["files"] = []
file_instence = variable_pool.get_instance(self.typed_config.body.data)
if isinstance(file_instence, ArrayVariable):
for v in file_instence.value:
if isinstance(v, FileVariable):
content["files"].append(
(
"files", (uuid.uuid4().hex, await v.get_content())
)
)
elif isinstance(file_instence, FileVariable):
content["files"].append(
(
"file", (uuid.uuid4().hex, await file_instence.get_content())
)
)
case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), variable_pool
@@ -207,14 +351,16 @@ class HttpRequestNode(BaseNode):
request_func = self._get_client_method(client)
resp = await request_func(
url=self._render_template(self.typed_config.url, variable_pool),
**self._build_content(variable_pool)
**(await self._build_content(variable_pool))
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")
response = HttpResponse(resp)
return HttpRequestNodeOutput(
body=resp.text,
body=response.body,
status_code=resp.status_code,
headers=resp.headers,
files=response.files
).model_dump()
except (httpx.HTTPStatusError, httpx.RequestError) as e:
logger.error(f"HTTP request node exception: {e}")
@@ -236,5 +382,5 @@ class HttpRequestNode(BaseNode):
logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR"
return {"output": "ERROR"}
raise RuntimeError("http request failed")

View File

@@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
)
knowledge_bases: list[KnowledgeBaseConfig] = Field(
...,
default_factory=list,
description="Knowledge base config"
)

View File

@@ -180,6 +180,8 @@ class KnowledgeRetrievalNode(BaseNode):
RuntimeError: If no valid knowledge base is found or access is denied.
"""
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
if not self.typed_config.knowledge_bases:
return []
query = self._render_template(self.typed_config.query, variable_pool)
with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases

View File

@@ -112,11 +112,12 @@ class LLMNode(BaseNode):
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
api_config = self.model_balance(config)
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
@@ -129,7 +130,8 @@ class LLMNode(BaseNode):
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params=extra_params
extra_params=extra_params,
is_omni=is_omni
),
type=ModelType(model_type)
)
@@ -151,39 +153,53 @@ class LLMNode(BaseNode):
if role == "system":
messages.append({
"role": "system",
"content": content
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
})
elif role in ["user", "human"]:
messages.append({
"role": "user",
"content": content
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
})
elif role in ["ai", "assistant"]:
messages.append({
"role": "assistant",
"content": content
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({
"role": "user",
"content": content
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
})
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_value(self.typed_config.vision_input)
for file in files:
content = await self.process_message(provider, file, self.typed_config.vision)
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
if content:
file_content.append(content)
file_content.extend(content)
if messages and messages[-1]["role"] == 'user':
messages[-1]['content'] = [messages[-1]["content"]] + file_content
messages[-1]['content'] = messages[-1]["content"] + file_content
else:
messages.append({"role": "user", "content": file_content})
if self.typed_config.memory.enable:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
history_message = []
for message in state["messages"][-self.typed_config.memory.window_size:]:
if isinstance(message["content"], list):
file_content = []
for file in message["content"]:
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
if content:
file_content.extend(content)
history_message.append(
{"role": message["role"], "content": file_content}
)
else:
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)

View File

@@ -123,10 +123,10 @@ class NodeFactory:
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
raise ValueError(f"Unsupported node type: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod

View File

@@ -0,0 +1,12 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class NoteNodeConfig(BaseNodeConfig):
author: str = Field(default="", description="author")
text: str = Field(default="", description="note content")
width: int = Field(default=80)
height: int = Field(default=80)
theme: str = Field(default="blue")
show_author: bool = Field(default=True)

View File

@@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode):
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0]
api_config = self.model_balance(config)
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type
llm = RedBearLLM(
@@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode):
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni
),
type=ModelType(model_type)
)

View File

@@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode):
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0]
api_config = self.model_balance(config)
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
base_url = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type
return RedBearLLM(
@@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode):
provider=provider,
api_key=api_key,
base_url=base_url,
is_omni=is_omni
),
type=ModelType(model_type)
)

View File

@@ -3,7 +3,6 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class StartNodeConfig(BaseNodeConfig):
@@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig):
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="message",
type=VariableType.STRING,
description="用户输入的消息"
),
VariableDefinition(
name="conversation_vars",
type=VariableType.OBJECT,
description="会话级变量"
),
VariableDefinition(
name="execution_id",
type=VariableType.STRING,
description="执行 ID"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="workspace_id",
type=VariableType.STRING,
description="工作空间 ID"
),
VariableDefinition(
name="user_id",
type=VariableType.STRING,
description="用户 ID"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
# # 输出变量定义
# output_variables: list[VariableDefinition] = Field(
# default_factory=lambda: [
# VariableDefinition(
# name="message",
# type=VariableType.STRING,
# description="用户输入的消息"
# ),
# VariableDefinition(
# name="conversation_vars",
# type=VariableType.OBJECT,
# description="会话级变量"
# ),
# VariableDefinition(
# name="execution_id",
# type=VariableType.STRING,
# description="执行 ID"
# ),
# VariableDefinition(
# name="conversation_id",
# type=VariableType.STRING,
# description="会话 ID"
# ),
# VariableDefinition(
# name="workspace_id",
# type=VariableType.STRING,
# description="工作空间 ID"
# ),
# VariableDefinition(
# name="user_id",
# type=VariableType.STRING,
# description="用户 ID"
# )
# ],
# description="输出变量定义(自动生成,通常不需要修改)"
# )
class Config:
json_schema_extra = {

Some files were not shown because too many files have changed in this diff Show More