Compare commits

...

74 Commits

Author SHA1 Message Date
Mark
524aed19d4 Revert "model and statistic" 2026-01-28 14:30:27 +08:00
Mark
7e56c09620 Merge pull request #218 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
model and statistic
2026-01-28 13:34:48 +08:00
lixinyue11
2e7f6afe3f Fix/memory bug fix (#217)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* 检查需要更改的格式问题

* 修复宿主列表获取memory_config_idBUG

* config_id做映射

* config_id做映射

* config_id做映射+1

* config_id做映射+1

* config_id做映射+1

* 应用层memory_content->memory_config

* 应用层memory_content->memory_config

* 应用层memory_content->memory_config

* 统一字段为config_id_old

* 统一字段为config_id_old

* 统一字段为config_id_old

* 统一字段为config_id_old

* memory_content暂时不修改

* memory_content暂时不修改

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-28 11:58:10 +08:00
Timebomb2018
9a4b1f0937 feat(model and app statistic): 1. Optimize the model list; 2. Increase the model combination; 3. Add a model square; 4. Add application management statistics 2026-01-28 11:42:45 +08:00
Timebomb2018
e5e914903c feat(model and app statistic): 1. Optimize the model list; 2. Increase the model combination; 3. Add a model square; 4. Add application management statistics 2026-01-28 11:04:46 +08:00
lixinyue11
7ba443afa5 Fix/memory bug fix (#215)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* 检查需要更改的格式问题

* 修复宿主列表获取memory_config_idBUG

* config_id做映射

* config_id做映射

* config_id做映射+1

* config_id做映射+1

* config_id做映射+1

* 应用层memory_content->memory_config

* 应用层memory_content->memory_config

* 应用层memory_content->memory_config

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-28 11:01:58 +08:00
Timebomb2018
2862db3534 feat(model and app statistic): 1. Optimize the model list; 2. Increase the model combination; 3. Add a model square; 4. Add application management statistics 2026-01-28 10:15:51 +08:00
lixinyue11
bf3e30dac0 Fix/memory bug fix (#212)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* 检查需要更改的格式问题

* 修复宿主列表获取memory_config_idBUG

* config_id做映射

* config_id做映射

* config_id做映射+1

* config_id做映射+1

* config_id做映射+1

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-28 10:07:32 +08:00
lixinyue11
375660f232 Fix/memory bug fix (#211)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* 检查需要更改的格式问题

* 修复宿主列表获取memory_config_idBUG

* config_id做映射

* config_id做映射

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-27 20:26:14 +08:00
Mark
f6031baee4 Merge pull request #210 from SuanmoSuanyangTechnology/fix/workflow-stream
fix(workflow): fix activation and branch control issues in streaming output
2026-01-27 20:09:48 +08:00
Eternity
c818ba7bc7 perf(workflow): make memory configuration backward compatible 2026-01-27 19:26:50 +08:00
Eternity
8fb9e779a6 feat(workflow): store token usage in message table 2026-01-27 18:52:51 +08:00
Eternity
c5a794f1b5 perf(workflow): enhance streaming output node activation performance 2026-01-27 18:39:47 +08:00
lixiangcheng1
3aa2cdd754 Merge branch 'feature/knowledge_lxc' into develop 2026-01-27 18:30:56 +08:00
lixiangcheng1
d93d52cf10 [fix]remove aspose-slides 2026-01-27 18:30:27 +08:00
Eternity
2abbd5a7fb fix(workflow): fix streaming output error when variable is not a string 2026-01-27 18:16:53 +08:00
Eternity
2a10e9f7ee style(workflow): enforce PEP8 style and remove redundant imports 2026-01-27 17:51:27 +08:00
Eternity
166d05afe9 fix(workflow): fix function cache not taking effect and potential list index overflow 2026-01-27 17:41:18 +08:00
Eternity
2eff8d1962 fix(workflow): fix activation and branch control issues in streaming output 2026-01-27 17:23:53 +08:00
Mark
93c9e76c4b [add] migration script 2026-01-27 15:31:29 +08:00
Mark
021cb09b82 Merge branch 'feature/plugin' into develop 2026-01-27 15:14:49 +08:00
Mark
28e6939884 [modify] file local server url 2026-01-27 15:06:50 +08:00
lixinyue11
8847039d76 Fix/memory bug fix (#209)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* 检查需要更改的格式问题

* 修复宿主列表获取memory_config_idBUG

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-27 14:36:37 +08:00
Mark
2694576a32 [add] plugin system and base sso module 2026-01-27 14:04:44 +08:00
yingzhao
e4f10670f6 Merge pull request #208 from SuanmoSuanyangTechnology/feature/codeNode_zy
fix(web): remove URI decode and encode
2026-01-27 13:51:55 +08:00
zhaoying
1324ba3a49 fix(web): remove URI decode and encode 2026-01-27 13:47:55 +08:00
lixinyue11
73c7810310 Fix/memory bug fix (#207)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* 检查需要更改的格式问题

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-27 11:45:14 +08:00
乐力齐
d160076267 Fix/redbear benchmark (#205)
* Refactor: Move evaluation folder to redbear-mem-benchmark submodule

* [changes]Update submodule reference

* Refactor: Move evaluation folder to redbear-mem-benchmark submodule

* [changes]Update submodule reference

* Remove duplicate evaluation submodule, use redbear-mem-benchmark instead
2026-01-27 11:44:50 +08:00
yingzhao
ed8c1c7c19 Merge pull request #206 from SuanmoSuanyangTechnology/feature/codeNode_zy
feat(web): workflow add code node
2026-01-27 11:41:12 +08:00
yingzhao
159c8d1ff9 Merge branch 'develop' into feature/codeNode_zy 2026-01-27 11:40:54 +08:00
Mark
8932d455d8 Merge pull request #202 from SuanmoSuanyangTechnology/feature/workflow-code
Feature/workflow code
2026-01-27 11:40:18 +08:00
zhaoying
3af183f6c3 feat(web): workflow add code node 2026-01-27 11:37:17 +08:00
乐力齐
c3ea3b751b delete benchmark-test (#204)
* Refactor: Move evaluation folder to redbear-mem-benchmark submodule

* [changes]Restore .gitmodules
2026-01-26 20:30:07 +08:00
Mark
e2c67d0c5b Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop 2026-01-26 19:19:59 +08:00
Mark
87731090ca [modify] migration script 2026-01-26 19:19:41 +08:00
乐力齐
80ca247435 Refactor/benchmark test (#196)
* [changes]refactor locomo_test

* [fix]Fix the circular import of ModelParameters

* [changes]The benchmark test can run stably.

* [fix]Complete end-to-end LoCoMo repair

* [fix]Complete the end-to-end longmemeval and memsciqa fixes

* [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect.

* [changes]refactor locomo_test

* [fix]Fix the circular import of ModelParameters

* [changes]The benchmark test can run stably.

* [fix]Complete end-to-end LoCoMo repair

* [fix]Complete the end-to-end longmemeval and memsciqa fixes

* [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect.

* [changes]Benchmark test adaptation for end_user_id

* [changes]refactor locomo_test

* [fix]Fix the circular import of ModelParameters

* [changes]The benchmark test can run stably.

* [fix]Complete end-to-end LoCoMo repair

* [fix]Complete the end-to-end longmemeval and memsciqa fixes

* [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect.

* [fix]Complete the end-to-end longmemeval and memsciqa fixes

* [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect.

* [changes]Benchmark test adaptation for end_user_id
2026-01-26 19:05:20 +08:00
lixinyue11
a5b8d3afa5 Fix/memory bug fix (#200)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

* user_id->显示为config_id_old传输

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-26 19:05:07 +08:00
Eternity
1f615a06ad fix(sandbox): treat non-zero exit codes as errors instead of relying only on stderr 2026-01-26 18:50:22 +08:00
yingzhao
4123560a98 Merge pull request #203 from SuanmoSuanyangTechnology/feature/workflow_runtime_zy
Feature/workflow runtime zy
2026-01-26 18:42:27 +08:00
zhaoying
5267bd60a5 fix(web): iteration's variable add parameter-extractor node 2026-01-26 18:40:28 +08:00
zhaoying
f76bffb482 fix(web): KnowledgeConfigModal bugfix 2026-01-26 18:32:18 +08:00
yingzhao
51185c83c9 Merge pull request #201 from SuanmoSuanyangTechnology/feature/memoryApi_zy
feat(web): update read_all_config select valueKey
2026-01-26 17:54:43 +08:00
Eternity
f1f887faae feat(workflow): Add a new node for executing code 2026-01-26 17:51:31 +08:00
zhaoying
46f0f3cee9 feat(web): update read_all_config select valueKey 2026-01-26 17:43:25 +08:00
lixinyue11
ebc41b2eec Fix/memory bug fix (#199)
* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 图谱数据量限制数量去掉

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 用户详情优化

* 读取的接口,去掉全局锁

* 输出数组

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化1.0(优化隐私输出、时间检索)

* 反思优化测试接口

* 反思优化测试接口

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 读取接口内层嵌套BUG修复

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段

* 把group_id替换end_user_id

* 把group_id替换end_user_id_

* 把group_id替换end_user_id_

* config_config替换成memory_config

* config_config替换成memory_config

* [fix]Fix the memory interface to use end_user_id.

* config_config替换成memory_config

* config_config替换成memory_config

* config_config替换成memory_config

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID

* config_id字段改成UUID,与develop校对恢复

* 检查项目,修复group_id的遗留问题

* 检查项目,修复group_id的遗留问题

* 解决冲突

* 解决冲突

* end_user_id清理干净

* end_user_id清理干净

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 修复遗留合并BUG

* 感知meta_data字段BUG修复

* user_id->现实为config_id_old

* user_id->显示为config_id_old传输

---------

Co-authored-by: lanceyq <1982376970@qq.com>
2026-01-26 17:22:48 +08:00
Eternity
3b4b474ce8 fix(sandbox): prevent imports from being blocked when network is disabled 2026-01-26 16:32:58 +08:00
yingzhao
4534e46811 Merge pull request #198 from SuanmoSuanyangTechnology/feature/workflow_runtime_zy
fix(web):  handleSSE bugfix
2026-01-26 16:01:27 +08:00
zhaoying
7bfa7b3f02 fix(web): handleSSE bugfix 2026-01-26 16:00:47 +08:00
yingzhao
1cc34d8e62 Merge pull request #197 from SuanmoSuanyangTechnology/feature/workflow_runtime_zy
feat(web): add workflow runtime info
2026-01-26 15:48:35 +08:00
zhaoying
2eff6b2e9d feat(web): add workflow runtime info 2026-01-26 15:46:28 +08:00
Mark
b046411302 [modify] migration script 2026-01-26 15:39:35 +08:00
Mark
6ab65b3626 Merge pull request #195 from SuanmoSuanyangTechnology/feature/workflow-code
Add SSE-based exception streaming and sandbox support for workflow
2026-01-26 14:30:53 +08:00
Mark
cf321f9b09 Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop 2026-01-26 14:26:40 +08:00
Mark
8228d38859 [add] migration script 2026-01-26 14:26:32 +08:00
yingzhao
c2e3110fa2 Merge pull request #194 from SuanmoSuanyangTechnology/feature/memoryApi_zy
feat(web): memory related interface parameter transfer adjustment
2026-01-26 12:56:52 +08:00
Eternity
85681db7b7 perf(workflow): update standard node output structure 2026-01-26 12:28:40 +08:00
Eternity
1fc04c37d3 perf(sandbox): optimize code encryption handling 2026-01-26 12:22:54 +08:00
Eternity
0fd8a122fb feat(workflow): emit SSE events for node exception output 2026-01-26 12:00:55 +08:00
Eternity
e3b6ede992 feat(sandbox): add Python 3 code execution sandbox support 2026-01-26 11:54:38 +08:00
lixinyue11
3601737869 Fix/memory bug fix (#171) 2026-01-26 11:53:34 +08:00
zhaoying
4f4f55d67f feat(web): memory related interface parameter transfer adjustment 2026-01-26 11:04:30 +08:00
Ke Sun
714c624dc6 Merge branch 'main' into develop 2026-01-25 12:44:34 +08:00
lixinyue11
1919580759 Fix/memory mcp2 1 (#190)
* 优化快速检索的回复内容

* 优化快速检索的回复内容

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* LLM生存缺少config_id认证,修复BUG

* LLM生存缺少config_id认证,修复BUG

* LLM生存缺少config_id认证,修复BUG

* 深度检索优化,搜索不到数据/提问的概念过于蘑菇,以引导的方式继续提问

* 深度检索优化,搜索不到数据/提问的概念过于蘑菇,以引导的方式继续提问

* 深度检索优化,搜索不到数据/提问的概念过于蘑菇,以引导的方式继续提问
2026-01-23 17:12:21 +08:00
Mark
b27ffe57e6 Merge pull request #189 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
feat(home page): version description update
2026-01-23 17:03:29 +08:00
Timebomb2018
c115bcde54 feat(home page): version description update 2026-01-23 16:58:55 +08:00
lixinyue11
313f19eba4 Fix/memory mcp2 1 (#188)
* 优化快速检索的回复内容

* 优化快速检索的回复内容

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* LLM生存缺少config_id认证,修复BUG

* LLM生存缺少config_id认证,修复BUG

* LLM生存缺少config_id认证,修复BUG
2026-01-23 14:49:44 +08:00
yingzhao
c6bcf53fea Merge pull request #186 from SuanmoSuanyangTechnology/feature/ui_zy
fix(web): workflow's variables bugfix
2026-01-23 14:02:13 +08:00
lixinyue11
86812b34d1 Fix/memory mcp2 1 (#185)
* 优化快速检索的回复内容

* 优化快速检索的回复内容

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复

* 路径的BUG修复
2026-01-23 13:57:27 +08:00
lixinyue11
15f9c49418 Fix/memory mcp2 1 (#184)
* 优化快速检索的回复内容

* 优化快速检索的回复内容
2026-01-23 12:21:54 +08:00
乐力齐
6e18c92a13 Fix/optimize inerface (#183)
* [changes]Optimize the time consumption of the "/end_users" interface

* [fix]Optimize the time consumption of the "/hot_memory_tags" interface

* [changes]Optimize the time consumption of the "/end_users" interface

* [fix]Optimize the time consumption of the "/hot_memory_tags" interface

* [changes]Improve the code based on AI review
2026-01-23 12:21:28 +08:00
乐力齐
7870c6c33f Fix/interface home (#182)
* [fix]Fix the interface for statistics of recent activities and applications

* [changes]Modify the code based on the AI review
1.Use the boolean auxiliary methods provided by SQLAlchemy instead of using == True in the is_active filter.
2.The calculation of the "PROJECT_ROOT" has now been hardcoded with five levels of nested os.path.dirname calls.

* [fix]Fix the interface for statistics of recent activities and applications

* [changes]Modify the code based on the AI review
1.Use the boolean auxiliary methods provided by SQLAlchemy instead of using == True in the is_active filter.
2.The calculation of the "PROJECT_ROOT" has now been hardcoded with five levels of nested os.path.dirname calls.
2026-01-23 10:50:24 +08:00
yujiangping
45adb9627a Merge branch 'feature/knowledgeBase_yjp' into develop 2026-01-22 20:59:36 +08:00
yujiangping
7219274d94 Merge branch 'release/v0.2.1' into develop 2026-01-22 20:21:29 +08:00
yujiangping
51680b7077 Merge branch 'feature/knowledgeBase_yjp' into develop 2026-01-22 16:44:58 +08:00
248 changed files with 5947 additions and 10889 deletions

3
.gitignore vendored
View File

@@ -35,3 +35,6 @@ nltk_data/
tika-server*.jar*
cl100k_base.tiktoken
libssl*.deb
sandbox/lib/seccomp_python/target
sandbox/lib/seccomp_nodejs/target

0
api/app/__init__.py Normal file
View File

View File

@@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field
from typing import Optional
from sqlalchemy.orm import Session
from uuid import UUID
from app.core.response_utils import success
from app.dependencies import get_current_user
@@ -32,11 +33,11 @@ router = APIRouter(
class EmotionConfigQuery(BaseModel):
"""情绪配置查询请求模型"""
config_id: int = Field(..., description="配置ID")
config_id: UUID = Field(..., description="配置ID")
class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型"""
config_id: int = Field(..., description="配置ID")
config_id: UUID = Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -45,7 +46,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
config_id: int = Query(..., description="配置ID"),
config_id: UUID = Query(..., description="配置ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):

View File

@@ -53,7 +53,7 @@ async def get_emotion_tags(
api_logger.info(
f"用户 {current_user.username} 请求获取情绪标签统计",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"emotion_type": request.emotion_type,
"start_date": request.start_date,
"end_date": request.end_date,
@@ -63,7 +63,7 @@ async def get_emotion_tags(
# 调用服务层
data = await emotion_service.get_emotion_tags(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
emotion_type=request.emotion_type,
start_date=request.start_date,
end_date=request.end_date,
@@ -73,7 +73,7 @@ async def get_emotion_tags(
api_logger.info(
"情绪标签统计获取成功",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"total_count": data.get("total_count", 0),
"tags_count": len(data.get("tags", []))
}
@@ -84,7 +84,7 @@ async def get_emotion_tags(
except Exception as e:
api_logger.error(
f"获取情绪标签统计失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
api_logger.info(
f"用户 {current_user.username} 请求获取情绪词云数据",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"emotion_type": request.emotion_type,
"limit": request.limit
}
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
# 调用服务层
data = await emotion_service.get_emotion_wordcloud(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
emotion_type=request.emotion_type,
limit=request.limit
)
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
api_logger.info(
"情绪词云数据获取成功",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"total_keywords": data.get("total_keywords", 0)
}
)
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
except Exception as e:
api_logger.error(
f"获取情绪词云数据失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
@@ -159,21 +159,21 @@ async def get_emotion_health(
api_logger.info(
f"用户 {current_user.username} 请求获取情绪健康指数",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"time_range": request.time_range
}
)
# 调用服务层
data = await emotion_service.calculate_emotion_health_index(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
time_range=request.time_range
)
api_logger.info(
"情绪健康指数获取成功",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"health_score": data.get("health_score", 0),
"level": data.get("level", "未知")
}
@@ -186,7 +186,7 @@ async def get_emotion_health(
except Exception as e:
api_logger.error(
f"获取情绪健康指数失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
"""获取个性化情绪建议(从缓存读取)
Args:
request: 包含 group_id 和可选的 config_id
request: 包含 end_user_id 和可选的 config_id
db: 数据库会话
current_user: 当前用户
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"config_id": request.config_id
}
)
# 从缓存获取建议
data = await emotion_service.get_cached_suggestions(
end_user_id=request.group_id,
end_user_id=request.end_user_id,
db=db
)
if data is None:
# 缓存不存在或已过期
api_logger.info(
f"用户 {request.group_id} 的建议缓存不存在或已过期",
extra={"group_id": request.group_id}
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
extra={"end_user_id": request.end_user_id}
)
return fail(
BizCode.NOT_FOUND,
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
api_logger.info(
"个性化建议获取成功(缓存)",
extra={
"group_id": request.group_id,
"end_user_id": request.end_user_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
except Exception as e:
api_logger.error(
f"获取个性化建议失败: {str(e)}",
extra={"group_id": request.group_id},
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(

View File

@@ -310,7 +310,7 @@ async def get_file_url(
try:
if permanent:
# Generate permanent URL (no expiration check)
server_url = f"http://{settings.SERVER_IP}:8000/api"
server_url = settings.FILE_LOCAL_SERVER_URL
url = f"{server_url}/storage/permanent/{file_id}"
return success(
data={

View File

@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
user_id: str,
end_user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"),
@@ -137,7 +137,7 @@ async def get_preference_tags(
Get user preference tags from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter
start_date: Optional start date filter
@@ -146,20 +146,20 @@ async def get_preference_tags(
Returns:
List of preference tags from cache
"""
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -192,17 +192,17 @@ async def get_preference_tags(
filtered_preferences.append(pref)
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
user_id: str,
end_user_id: str,
include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
Get user's four-dimension personality portrait from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
include_history: Whether to include historical trend data (ignored for cached data)
Returns:
Four-dimension personality portrait from cache
"""
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
# Extract portrait from cache
portrait = cached_profile.get("portrait", {})
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
return success(data=portrait, msg="四维画像获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", user_id)
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
user_id: str,
end_user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
Get user's interest area distribution from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
include_trends: Whether to include trend analysis data (ignored for cached data)
Returns:
Interest area distribution from cache
"""
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
# Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {})
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
@router.get("/habits/{user_id}", response_model=ApiResponse)
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
user_id: str,
end_user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
@@ -309,7 +309,7 @@ async def get_behavior_habits(
Get user's behavioral habits from cache.
Args:
user_id: Target user ID
end_user_id: Target end user ID
confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past)
@@ -317,20 +317,20 @@ async def get_behavior_habits(
Returns:
List of behavioral habits from cache
"""
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -368,11 +368,11 @@ async def get_behavior_habits(
filtered_habits.append(habit)
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)

View File

@@ -125,7 +125,7 @@ async def write_server(
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and group_id
user_input: Write request containing message and end_user_id
Returns:
Response with write operation status
@@ -160,19 +160,18 @@ async def write_server(
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
user_input.group_id,
messages_list, # 传递结构化消息列表
user_input.end_user_id,
messages_list,
config_id,
db,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -196,7 +195,7 @@ async def write_server_async(
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and group_id
user_input: Write request containing message and end_user_id
Returns:
Task ID for tracking async operation
@@ -226,10 +225,10 @@ async def write_server_async(
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
@@ -255,16 +254,14 @@ async def read_server(
- "2": Direct answer based on context
Args:
user_input: Read request with message, history, search_switch, and group_id
user_input: Read request with message, history, search_switch, and end_user_id
Returns:
Response with query answer
"""
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
@@ -279,12 +276,13 @@ async def read_server(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.group_id,
user_input.end_user_id,
user_input.message,
user_input.history,
user_input.search_switch,
@@ -295,17 +293,20 @@ async def read_server(
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
end_user_id=user_input.end_user_id,
retrieve_info=retrieve_info,
history=history,
query=query,
config_id=config_id,
db=db
)
if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info
return success(data=result, msg="回复对话消息成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -403,7 +404,7 @@ async def read_server_async(
try:
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
@@ -447,7 +448,7 @@ async def get_read_task_result(
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -524,7 +525,7 @@ async def get_write_task_result(
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -578,16 +579,16 @@ async def status_type(
Determine the type of user message (read or write)
Args:
user_input: Request containing user message and group_id
user_input: Request containing user message and end_user_id
Returns:
Type classification result
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
@@ -595,11 +596,11 @@ async def status_type(
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
last_user_message,
user_input.config_id,
@@ -624,7 +625,7 @@ async def get_knowledge_type_stats_api(
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
@@ -697,7 +698,7 @@ async def get_user_profile_api(
current_user: User = Depends(get_current_user)
):
"""
获取工作空间下Popular Memory Tags,包含:
获取用户详情,包含:
- name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签

View File

@@ -49,63 +49,134 @@ async def get_workspace_end_users(
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的宿主列表
获取工作空间的宿主列表(高性能优化版本 v2
返回格式与原 memory_list 接口中的 end_users 字段相同,
并包含每个用户的记忆配置信息memory_config_id 和 memory_config_name
优化策略:
1. 批量查询 end_users一次查询而非循环
2. 并发查询所有用户的记忆数量Neo4j
3. RAG 模式使用批量查询(一次 SQL
4. 只返回必要字段减少数据传输
5. 添加短期缓存减少重复查询
6. 并发执行配置查询和记忆数量查询
返回格式:
{
"end_user": {"id": "uuid", "other_name": "名称"},
"memory_num": {"total": 数量},
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
}
"""
import asyncio
import json
from app.aioRedis import aio_redis_get, aio_redis_set
workspace_id = current_user.current_workspace_id
# 尝试从缓存获取30秒缓存
cache_key = f"end_users:workspace:{workspace_id}"
try:
cached_data = await aio_redis_get(cache_key)
if cached_data:
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
except Exception as e:
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
# 获取 end_users已优化为批量查询
end_users = memory_dashboard_service.get_workspace_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
end_user_ids = [str(user.id) for user in end_users]
memory_configs_map = {}
if end_user_ids:
if not end_users:
api_logger.info("工作空间下没有宿主")
# 缓存空结果,避免重复查询
try:
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
await aio_redis_set(cache_key, json.dumps([]), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
return success(data=[], msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users]
# 并发执行两个独立的查询任务
async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)"""
try:
return await asyncio.to_thread(
get_end_users_connected_configs_batch,
end_user_ids, db
)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
# 失败时使用空字典,不影响其他数据返回
return {}
async def get_memory_nums():
"""获取记忆数量"""
if current_workspace_type == "rag":
# RAG 模式:批量查询
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式:并发查询(带并发限制)
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
try:
return await memory_storage_service.search_all(end_user_id)
except Exception as e:
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids}
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果(优化:使用列表推导式)
result = []
for end_user in end_users:
memory_num = {}
if current_workspace_type == "neo4j":
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
memory_num = await memory_storage_service.search_all(str(end_user.id))
elif current_workspace_type == "rag":
memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
}
# 从批量查询结果中获取配置信息
user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, {
"memory_config_id": None,
"memory_config_name": None
})
# 只保留需要的字段,移除 error 字段(如果有)
memory_config = {
"memory_config_id": memory_config_info.get("memory_config_id"),
"memory_config_name": memory_config_info.get("memory_config_name")
}
result.append(
{
'end_user': end_user,
'memory_num': memory_num,
'memory_config': memory_config
config_info = memory_configs_map.get(user_id, {})
result.append({
'end_user': {
'id': user_id,
'other_name': end_user.other_name
},
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
'memory_config': {
"memory_config_id": config_info.get("memory_config_id"),
"memory_config_name": config_info.get("memory_config_name")
}
)
})
# 写入缓存30秒过期
try:
await aio_redis_set(cache_key, json.dumps(result), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功")

View File

@@ -11,6 +11,7 @@
"""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -106,7 +107,7 @@ async def trigger_forgetting_cycle(
# 调用服务层执行遗忘周期
report = await forget_service.trigger_forgetting_cycle(
db=db,
group_id=end_user_id, # 服务层方法的参数名是 group_id
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access,
config_id=config_id
@@ -128,7 +129,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
config_id: int,
config_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -236,7 +237,7 @@ async def update_forgetting_config(
@router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats(
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -246,7 +247,7 @@ async def get_forgetting_stats(
返回知识层节点统计、激活值分布等信息。
Args:
group_id: 组ID即 end_user_id可选
end_user_id: 组ID即 end_user_id可选
current_user: 当前用户
db: 数据库会话
@@ -260,20 +261,20 @@ async def get_forgetting_stats(
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 group_id通过它获取 config_id
# 如果提供了 end_user_id通过它获取 config_id
config_id = None
if group_id:
if end_user_id:
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
@@ -283,14 +284,14 @@ async def get_forgetting_stats(
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"group_id={group_id}, config_id={config_id}"
f"end_user_id={end_user_id}, config_id={config_id}"
)
try:
# 调用服务层获取统计信息
stats = await forget_service.get_forgetting_stats(
db=db,
group_id=group_id,
end_user_id=end_user_id,
config_id=config_id
)

View File

@@ -27,27 +27,27 @@ router = APIRouter(
)
@router.get("/{group_id}/count", response_model=ApiResponse)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve perceptual memory statistics for a user group.
Args:
group_id: ID of the user group (usually end_user_id in this context)
end_user_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Response containing memory count statistics
"""
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
try:
service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(group_id)
count_stats = service.get_memory_count(end_user_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
@@ -57,37 +57,37 @@ def get_memory_count(
)
except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics",
)
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent VISION-type memory for a user.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest visual memory
"""
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
try:
service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(group_id)
visual_memory = service.get_latest_visual_memory(end_user_id)
if visual_memory is None:
api_logger.info(f"No visual memory found: group_id={group_id}")
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
return success(
data=None,
msg="No visual memory available"
@@ -101,37 +101,37 @@ def get_last_visual_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory",
)
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent AUDIO-type memory for a user.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest audio memory
"""
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
try:
service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(group_id)
audio_memory = service.get_latest_audio_memory(end_user_id)
if audio_memory is None:
api_logger.info(f"No audio memory found: group_id={group_id}")
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
return success(
data=None,
msg="No audio memory available"
@@ -145,38 +145,38 @@ def get_last_memory_listen(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory",
)
@router.get("/{group_id}/last_text", response_model=ApiResponse)
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent TEXT-type memory for a user.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest text memory
"""
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
try:
# 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(group_id)
text_memory = service.get_latest_text_memory(end_user_id)
if text_memory is None:
api_logger.info(f"No text memory found: group_id={group_id}")
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
return success(
data=None,
msg="No text memory available"
@@ -190,16 +190,16 @@ def get_last_text_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory",
)
@router.get("/{group_id}/timeline", response_model=ApiResponse)
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
@@ -209,7 +209,7 @@ def get_memory_time_line(
"""Retrieve a timeline of perceptual memories for a user group.
Args:
group_id: ID of the user group
end_user_id: ID of the user group
perceptual_type: Optional filter for perceptual type
page: Page number for pagination
page_size: Number of items per page
@@ -221,7 +221,7 @@ def get_memory_time_line(
"""
api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, "
f"group_id={group_id}, type={perceptual_type}, page={page}"
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
)
try:
@@ -232,7 +232,7 @@ def get_memory_time_line(
)
service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(group_id, query)
timeline_data = service.get_time_line(end_user_id, query)
api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
@@ -246,7 +246,7 @@ def get_memory_time_line(
except Exception as e:
api_logger.error(
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
f"error={str(e)}"
)
return fail(

View File

@@ -1,6 +1,7 @@
import asyncio
import time
import uuid
from uuid import UUID
from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
@@ -11,7 +12,7 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import (
@@ -50,7 +51,7 @@ async def save_reflection_config(
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
data_config = DataConfigRepository.update_reflection_config(
memory_config = MemoryConfigRepository.update_reflection_config(
db,
config_id=config_id,
enable_self_reflexion=request.reflection_enabled,
@@ -63,17 +64,17 @@ async def save_reflection_config(
)
db.commit()
db.refresh(data_config)
db.refresh(memory_config)
reflection_result={
"config_id": data_config.config_id,
"enable_self_reflexion": data_config.enable_self_reflexion,
"iteration_period": data_config.iteration_period,
"reflexion_range": data_config.reflexion_range,
"baseline": data_config.baseline,
"reflection_model_id": data_config.reflection_model_id,
"memory_verify": data_config.memory_verify,
"quality_assessment": data_config.quality_assessment}
"config_id": memory_config.config_id,
"enable_self_reflexion": memory_config.enable_self_reflexion,
"iteration_period": memory_config.iteration_period,
"reflexion_range": memory_config.reflexion_range,
"baseline": memory_config.baseline,
"reflection_model_id": memory_config.reflection_model_id,
"memory_verify": memory_config.memory_verify,
"quality_assessment": memory_config.quality_assessment}
return success(data=reflection_result, msg="反思配置成功")
@@ -111,14 +112,14 @@ async def start_workspace_reflection(
reflection_results = []
for data in result['apps_detailed_info']:
if data['data_configs'] == []:
if data['memory_configs'] == []:
continue
releases = data['releases']
data_configs = data['data_configs']
memory_configs = data['memory_configs']
end_users = data['end_users']
for base, config, user in zip(releases, data_configs, end_users):
for base, config, user in zip(releases, memory_configs, end_users):
# 安全地转换为整数处理空字符串和None的情况
print(base['config'])
try:
@@ -156,14 +157,14 @@ async def start_workspace_reflection(
@router.get("/reflection/configs")
async def start_reflection_configs(
config_id: int,
config_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""通过config_id查询data_config表中的反思配置信息"""
"""通过config_id查询memory_config表中的反思配置信息"""
try:
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
# 构建返回数据
reflection_config = {
"config_id": result.config_id,
@@ -191,7 +192,7 @@ async def start_reflection_configs(
@router.get("/reflection/run")
async def reflection_run(
config_id: int,
config_id: UUID,
language_type: str = Header(default="zh", alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
@@ -200,8 +201,8 @@ async def reflection_run(
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
# 使用DataConfigRepository查询反思配置
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
# 使用MemoryConfigRepository查询反思配置
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,

View File

@@ -1,5 +1,6 @@
import os
from typing import Optional
from uuid import UUID
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
@@ -160,7 +161,7 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: str,
config_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -232,7 +233,7 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
config_id: str,
config_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -420,15 +421,95 @@ async def get_hot_memory_tags_api(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
"""
获取热门记忆标签带Redis缓存
缓存策略:
- 缓存键workspace_id + limit
- 过期时间5分钟300秒
- 缓存命中:~50ms
- 缓存未命中:~600-800ms取决于LLM速度
"""
workspace_id = current_user.current_workspace_id
# 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
try:
# 尝试从Redis缓存获取
from app.aioRedis import aio_redis_get, aio_redis_set
import json
cached_result = await aio_redis_get(cache_key)
if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}")
try:
data = json.loads(cached_result)
return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh")
# 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit)
# 写入缓存过期时间5分钟
# 注意result是列表需要转换为JSON字符串
try:
cache_data = json.dumps(result, ensure_ascii=False)
await aio_redis_set(cache_key, cache_data, expire=300)
api_logger.info(f"Cached result for key: {cache_key}")
except Exception as cache_error:
# 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user),
) -> dict:
"""
清除热门标签缓存
用于:
- 手动刷新数据
- 调试和测试
- 数据更新后立即生效
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
try:
from app.aioRedis import aio_redis_delete
# 清除所有limit的缓存常见的limit值
cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]:
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
result = await aio_redis_delete(cache_key)
if result:
cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}")
return success(
data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存"
)
except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user),

View File

@@ -20,18 +20,18 @@ router = APIRouter(
)
@router.get("/{group_id}/count", response_model=ApiResponse)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
pass
@router.get("/{group_id}/conversations", response_model=ApiResponse)
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
def get_conversations(
group_id: uuid.UUID,
end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -39,7 +39,7 @@ def get_conversations(
Retrieve all conversations for the current user in a specific group.
Args:
group_id (UUID): The group identifier.
end_user_id (UUID): The group identifier.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
@@ -53,7 +53,7 @@ def get_conversations(
"""
conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations(
group_id
end_user_id
)
return success(data=[
{
@@ -63,7 +63,7 @@ def get_conversations(
], msg="get conversations success")
@router.get("/{group_id}/messages", response_model=ApiResponse)
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
def get_messages(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
@@ -100,7 +100,7 @@ def get_messages(
return success(data=messages, msg="get conversation history success")
@router.get("/{group_id}/detail", response_model=ApiResponse)
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
async def get_conversation_detail(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),

View File

@@ -317,9 +317,12 @@ async def chat(
appid = share.app_id
"""获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
from app.models.app_model import App
app = db.query(App).filter(App.id == appid).first()
app = db.query(App).filter(
App.id == appid,
App.is_active.is_(True)
).first()
if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
memory_api_service = MemoryAPIService(db)

View File

@@ -135,27 +135,27 @@ async def generate_cache_api(
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
group_id = request.end_user_id
end_user_id = request.end_user_id
api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={group_id if group_id else '全部用户'}"
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
)
try:
if group_id:
if end_user_id:
# 为单个用户生成
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
# 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
# 构建响应
result = {
"end_user_id": group_id,
"end_user_id": end_user_id,
"insight_success": insight_result["success"],
"summary_success": summary_result["success"],
"errors": []
@@ -175,9 +175,9 @@ async def generate_cache_api(
# 记录结果
if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {group_id} 生成缓存")
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
else:
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成")

View File

@@ -54,7 +54,7 @@ async def create_workflow_config(
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
App.is_active.is_(True)
).first()
if not app:
@@ -214,7 +214,7 @@ async def delete_workflow_config(
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
App.is_active.is_(True)
).first()
if not app:
@@ -259,7 +259,7 @@ async def validate_workflow_config(
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
App.is_active.is_(True)
).first()
if not app:
@@ -329,7 +329,7 @@ async def get_workflow_executions(
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
App.is_active.is_(True)
).first()
if not app:
@@ -389,7 +389,7 @@ async def get_workflow_execution(
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
App.is_active.is_(True)
).first()
if not app:
@@ -440,7 +440,7 @@ async def run_workflow(
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
App.is_active.is_(True)
).first()
if not app:
@@ -578,7 +578,7 @@ async def cancel_workflow_execution(
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
App.is_active.is_(True)
).first()
if not app:

View File

@@ -155,13 +155,13 @@ class LangChainAgent:
# userid=end_user_end,
# messages=messages,
# apply_id=end_user_end,
# group_id=end_user_end,
# end_user_id=end_user_end,
# aimessages=aimessages
# )
# store.delete_duplicate_sessions()
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# return session_id
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_redis_read(self,end_user_end):
# end_user_end = f"Term_{end_user_end}"
@@ -179,7 +179,7 @@ class LangChainAgent:
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
@@ -188,7 +188,7 @@ class LangChainAgent:
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
@@ -204,20 +204,20 @@ class LangChainAgent:
else:
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
@@ -228,7 +228,7 @@ class LangChainAgent:
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"

View File

@@ -9,6 +9,25 @@ load_dotenv()
class Settings:
# ========================================================================
# Deployment Mode Configuration
# ========================================================================
# community: 社区版(开源,功能受限)
# cloud: SaaS 云服务版(全功能,按量计费)
# enterprise: 企业私有化版License 控制)
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
# License 配置(企业版)
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
# 计费服务配置SaaS 版)
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
# 基础 URL用于 SSO 回调等)
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
@@ -72,6 +91,10 @@ class Settings:
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# SSO 免登配置
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
@@ -107,6 +130,7 @@ class Settings:
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
@@ -184,7 +208,7 @@ class Settings:
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))

View File

@@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
# 从状态中获取数据
content = state.get('data', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None)
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
databasets = {}
data = []
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()

View File

@@ -52,9 +52,9 @@ async def rag_config(state):
return kb_config
async def rag_knowledge(state,question):
kb_config = await rag_config(state)
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '')
group_id=state.get('group_id', '')
end_user_id=state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original=state.get('data', '')
problem_list=[]
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"end_user_id": end_user_id,
"question": question,
"return_raw_results": True
}
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取group_id
# 从state中获取end_user_id
import time
start=time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
temperature=0.2,
)
time_retrieval_tool = create_time_retrieval_tool(group_id)
search_params = { "group_id": group_id, "return_raw_results": True }
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool,hybrid_retrieval],
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
# 创建异步任务处理单个问题

View File

@@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
db_session = next(get_db())
@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
summary_service = SummaryNodeService()
async def summary_history(state: ReadState) -> ReadState:
group_id = state.get("group_id", '')
history = await SessionService(store).get_history(group_id, group_id, group_id)
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
data = state.get("data", '')
group_id = state.get("group_id", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
user_id=group_id,
user_id=end_user_id,
query=data,
apply_id=group_id,
group_id=group_id,
apply_id=end_user_id,
end_user_id=end_user_id,
ai_response=aimessages
)
await SessionService(store).cleanup_duplicates()
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
group_id=state.get("group_id", '')
end_user_id=state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state)
search_params = {
"group_id": group_id,
"end_user_id": end_user_id,
"question": data,
"return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
@@ -236,7 +236,7 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve_info_str='\n'.join(retrieve_info_str)
aimessages=await summary_llm(state,history,retrieve_info_str,
'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -276,7 +276,6 @@ async def Summary(state: ReadState)-> ReadState:
aimessages=await summary_llm(state,history,data,
'summary_prompt.jinja2','summary',SummaryResponse,0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -295,9 +294,26 @@ async def Summary(state: ReadState)-> ReadState:
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '')
history = await summary_history(state)
query = state.get("data", '')
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= {
"status": "success",
"summary_result": "没有相关数据",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}

View File

@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})

View File

@@ -1,23 +1,24 @@
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
async def write_node(state: WriteState) -> WriteState:
"""
Write data to the database/file system.
Args:
state: WriteState containing messages, group_id, and memory_config
state: WriteState containing messages, end_user_id, and memory_config
Returns:
dict: Contains 'write_result' with status and data fields
"""
messages = state.get('messages', [])
group_id = state.get('group_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
@@ -28,13 +29,11 @@ async def write_node(state: WriteState) -> WriteState:
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try:
result = await write(
messages=structured_messages,
user_id=group_id,
apply_id=group_id,
group_id=group_id,
end_user_id=end_user_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")

View File

@@ -79,7 +79,7 @@ async def make_read_graph():
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
group_id = '88a459f5_text09' # 组ID
end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
@@ -95,9 +95,9 @@ async def main():
start=time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
# 获取节点更新信息
_intermediate_outputs = []

View File

@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式"""
context: str = Field(description="用户输入的查询内容")
group_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(group_id: str):
def create_time_retrieval_tool(end_user_id: str):
"""
创建一个带有特定group_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
"""
def clean_temporal_result_fields(data):
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
return data
@tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数:
- context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD
- group_id_param: 组ID可选用于覆盖默认组ID
- end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
"""
async def _async_search():
# 使用传入的参数或默认值
actual_group_id = group_id_param or group_id
actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索
results = await search_by_temporal(
group_id=actual_group_id,
end_user_id=actual_end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=10
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
# 关键词时间搜索
results = await search_by_keyword_temporal(
query_text=context,
group_id=group_id,
end_user_id=end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=15
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
Args:
memory_config: 内存配置对象
**search_params: 搜索参数,包含group_id, limit, include等
**search_params: 搜索参数,包含end_user_id, limit, include等
"""
def clean_result_fields(data):
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: str,
search_type: str = "hybrid",
limit: int = 10,
group_id: str = None,
end_user_id: str = None,
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
group_id: 组ID用于过滤搜索结果
end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
final_params = {
"query_text": context,
"search_type": search_type,
"group_id": group_id or search_params.get("group_id"),
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: str,
search_type: str = "hybrid",
limit: int = 10,
group_id: str = None,
end_user_id: str = None,
clean_output: bool = True
) -> str:
"""
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
group_id: 组ID用于过滤搜索结果
end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段
"""
async def _async_search():
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"context": context,
"search_type": search_type,
"limit": limit,
"group_id": group_id,
"end_user_id": end_user_id,
"clean_output": clean_output
})

View File

@@ -14,6 +14,7 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -26,9 +27,21 @@ async def make_write_graph():
"""
Create a write graph workflow for memory operations.
The workflow directly processes messages from the initial state
and saves them to Neo4j storage.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
@@ -42,7 +55,7 @@ async def make_write_graph():
async def main():
"""主函数 - 运行工作流"""
message = "今天周一"
group_id = 'new_2025test1103' # 组ID
end_user_id = 'new_2025test1103' # 组ID
# 获取数据库会话
@@ -54,9 +67,9 @@ async def main():
)
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息
async for update_event in graph.astream(

View File

@@ -24,7 +24,7 @@ class ParameterBuilder:
tool_call_id: str,
search_switch: str,
apply_id: str,
group_id: str,
end_user_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
@@ -44,7 +44,7 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,7 +55,7 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id
"end_user_id": end_user_id
}
# Always add storage_type and user_rag_memory_id (with defaults if None)

View File

@@ -91,7 +91,7 @@ class SearchService:
async def execute_hybrid_search(
self,
group_id: str,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
@@ -105,7 +105,7 @@ class SearchService:
Execute hybrid search and return clean content.
Args:
group_id: Group identifier for filtering results
end_user_id: Group identifier for filtering results
question: Search query text
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
@@ -130,7 +130,7 @@ class SearchService:
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
output_path=output_path,
@@ -186,7 +186,7 @@ class SearchService:
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}",
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
exc_info=True
)
# Return empty results on failure

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
group_id: str
end_user_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
group_id: str,
end_user_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
end_user_id=end_user_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- end_user_id
- messages
- aimessages

View File

@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
end_user_id: str = "group_1",
messages: list = None,
ref_id: str = "wyl_20251027",
config_id: str = None
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
end_user_id: Group identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing
@@ -32,42 +28,40 @@ async def get_chunked_dialogs(
"""
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
if not messages or not isinstance(messages, list) or len(messages) == 0:
raise ValueError("messages parameter must be a non-empty list")
conversation_messages = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
role = msg['role']
content = msg['content']
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
conversation_context = ConversationContext(msgs=conversation_messages)
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
end_user_id=end_user_id,
config_id=config_id
)
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
return [dialog_data]

View File

@@ -1,24 +1,23 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
class WriteState(TypedDict):
'''
Langgrapg Writing TypedDict
'''
messages: Annotated[list[AnyMessage], add_messages]
user_id:str
apply_id:str
group_id:str
end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object
write_result: dict
data:str
data: str
class ReadState(TypedDict):
"""
@@ -28,7 +27,7 @@ class ReadState(TypedDict):
messages: 消息列表,支持自动追加
loop_count: 遍历次数
search_switch: 搜索类型开关
group_id: 组标识
end_user_id: 组标识
config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果
@@ -39,7 +38,7 @@ class ReadState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int
search_switch: str
group_id: str
end_user_id: str
config_id: str
data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果

View File

@@ -0,0 +1,61 @@
# 角色
你是一个智能问答助手,基于检索信息和历史对话回答用户问题。
# 任务
根据提供的上下文信息回答用户的问题。
# 输入信息
- 历史对话:{{history}}
- 检索信息:{{retrieve_info}}
# 用户问题
{{query}}
# 回答指南
## 1. 仔细阅读检索信息
- 答案可能直接或间接地出现在检索信息中
- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼"
- 第三人称描述的偏好、行为通常指用户本人
## 2. 判断信息相关性
**情况A信息匹配问题**
- 直接回答,像自然对话一样
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
**情况B信息部分相关**
- 先回答已知部分,再自然地询问更多信息
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
**情况C信息完全不相关**
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
- 使用友好的表达:
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
- "我不记得你提到过...,但你[检索到的相关信息]"
- 即使检索信息不直接回答问题,也可以自然地融入对话中
- 避免僵硬的"信息不足,无法回答"
## 3. 回答要求
- 像人类对话一样自然流畅
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
- 不要解释推理过程或引用信息来源
- 保持友好、乐于助人的语气
- 使用与问题相同的语言回答
# 关键示例
**示例1 - 直接匹配:**
- 检索信息:"小曼会使用Python..."
- 问题:"我叫什么"
- ✓ 正确:"你叫小曼"
- ✗ 错误:"你没有告诉我你的名字"
**示例2 - 间接匹配:**
- 检索信息:"用户很喜欢吃星巴克的甜品"
- 问题:"我喜欢什么"
- ✓ 正确:"你很喜欢吃星巴克的甜品"
- ✗ 错误:"信息不足"
**示例3 - 信息不匹配(推荐做法):**
- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦"
- 问题:"我吃过哪家面包"
- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?"
- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?"
- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问)
- ✗ 错误:"信息不足,无法回答。"(太僵硬)
# 重要提醒
- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字
- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度
- 用对话式语言表达"不知道",而非机械模板
- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆

View File

@@ -0,0 +1,43 @@
{# 角色定义 #}
你是专业的问题解答专家+引导学者
{# 输入数据展示 #}
{% if data %}
## 输入数据
上下文信息:
{% for item in data.history %}
- {{ item }}
{% endfor %}
检索到的所有信息:
{% for item in data.retrieve_info %}
- {{ item }}
{% endfor %}
{% endif %}
## User Query
{{ query }}
{# 问题回答标准 #}
## 问题回答核心标准
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。
注意,仔细阅读检索信息,答案可能直接或间接地出现在检索信息中或者历史上下文消息中,同时需要 判断信息相关性
**情况A信息匹配问题**
- 直接回答,像自然对话一样
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
**情况B信息部分相关**
- 先回答已知部分,再自然地询问更多信息
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
**情况C信息完全不相关**
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
- 使用友好的表达:
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
- "我不记得你提到过...,但你[检索到的相关信息]"
- 即使检索信息不直接回答问题,也可以自然地融入对话中
- 避免僵硬的"信息不足,无法回答"
{# 重要提醒 #}
当检索以及上下文的历史信息都无法回答的时候,可引导对方进行提问/回答,或者进行其他引导
当检索或者上下文中出现了,相似的问题,可以委婉,提醒对方,我记得刚刚提过这个问题,但是我自己不记得了,能在描述一次吗~以此为例

View File

@@ -28,7 +28,7 @@ class RedisSessionStore:
return text
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id):
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
"""
写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒
@@ -46,7 +46,7 @@ class RedisSessionStore:
"id": self.uudi,
"sessionid": userid,
"apply_id": apply_id,
"group_id": group_id,
"end_user_id": end_user_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
@@ -67,7 +67,7 @@ class RedisSessionStore:
def save_sessions_batch(self, sessions_data):
"""
批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能
"""
try:
@@ -83,7 +83,7 @@ class RedisSessionStore:
"id": self.uudi,
"sessionid": session.get('userid'),
"apply_id": session.get('apply_id'),
"group_id": session.get('group_id'),
"end_user_id": session.get('end_user_id'),
"messages": session.get('messages'),
"aimessages": session.get('aimessages'),
"starttime": starttime
@@ -108,9 +108,9 @@ class RedisSessionStore:
data = self.r.hgetall(key)
return data if data else None
def get_session_apply_group(self, sessionid, apply_id, group_id):
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
"""
result_items = []
@@ -124,7 +124,7 @@ class RedisSessionStore:
# 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and
data.get('group_id') == group_id):
data.get('end_user_id') == end_user_id):
result_items.append(data)
return result_items
@@ -172,7 +172,7 @@ class RedisSessionStore:
def delete_duplicate_sessions(self):
"""
删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
"sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成
"""
import time
@@ -202,12 +202,12 @@ class RedisSessionStore:
# 获取五个字段的值
sessionid = data.get('sessionid', '')
user_id = data.get('id', '')
group_id = data.get('group_id', '')
end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '')
aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages)
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
if identifier in seen:
# 重复,标记为待删除
@@ -248,9 +248,9 @@ class RedisSessionStore:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id):
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据返回最新的6条
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
"""
import time
start_time = time.time()
@@ -276,7 +276,7 @@ class RedisSessionStore:
# 检查是否符合三个条件
if (data.get('apply_id') == apply_id and
data.get('group_id') == group_id):
data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
group_id: str
end_user_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
group_id: str,
end_user_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
end_user_id=end_user_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- end_user_id
- messages
- aimessages

View File

@@ -29,20 +29,18 @@ logger = get_agent_logger(__name__)
async def write(
user_id: str,
apply_id: str,
group_id: str,
end_user_id: str,
memory_config: MemoryConfig,
messages: list,
ref_id: str = "wyl20251027",
) -> None:
"""
Execute the complete knowledge extraction pipeline.
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027"
@@ -51,14 +49,14 @@ async def write(
embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy
config_id = str(memory_config.config_id)
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
logger.info(f"Workspace: {memory_config.workspace_name}")
logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}")
logger.info(f"end_user_id ID: {end_user_id}")
# Construct clients from memory_config using factory pattern with db session
with get_db_context() as db:
@@ -83,9 +81,7 @@ async def write(
step_start = time.time()
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
end_user_id=end_user_id,
messages=messages,
ref_id=ref_id,
config_id=config_id,

View File

@@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]:
def get_default_docs_path() -> str:
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2])
return os.path.join(project_root, "src", "analytics", "API接口.md")

View File

@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
Args:
tags: 原始标签列表
group_id: 用户组ID用于获取配置
end_user_id: 用户组ID用于获取配置
Returns:
筛选后的标签列表
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if not config_id:
raise ValueError(
f"No memory_config_id found for group_id: {group_id}. "
f"No memory_config_id found for end_user_id: {end_user_id}. "
"Please ensure the user has a valid memory configuration."
)
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def get_raw_tags_from_db(
connector: Neo4jConnector,
group_id: str,
end_user_id: str,
limit: int,
by_user: bool = False
) -> List[Tuple[str, int]]:
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
Args:
connector: Neo4j连接器实例
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id
end_user_id: 如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询
by_user: 是否按user_id查询默认Falseend_user_id查询
Returns:
List[Tuple[str, int]]: 标签名称和频率的元组列表
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
else:
query = (
"MATCH (e:ExtractedEntity) "
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC "
"LIMIT $limit"
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
# 使用项目的Neo4jConnector执行查询
results = await connector.execute_query(
query,
id=group_id,
id=end_user_id,
limit=limit,
names_to_exclude=names_to_exclude
)
return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
"""
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
Args:
group_id: 必需参数。如果by_user=False则为group_id如果by_user=True则为user_id
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询
by_user: 是否按user_id查询默认Falseend_user_id查询
Raises:
ValueError: 如果group_id未提供或为空
ValueError: 如果end_user_id未提供或为空
"""
# 验证group_id必须提供且不为空
if not group_id or not group_id.strip():
# 验证end_user_id必须提供且不为空
if not end_user_id or not end_user_id.strip():
raise ValueError(
"group_id is required. Please provide a valid group_id or user_id."
"end_user_id is required. Please provide a valid end_user_id or user_id."
)
# 使用项目的Neo4jConnector
connector = Neo4jConnector()
try:
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
if not raw_tags_with_freq:
return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = []

View File

@@ -75,8 +75,8 @@ class MemoryDataSource:
start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id(
group_id=user_id,
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
end_user_id=user_id,
limit=limit,
start_date=start_date,
end_date=end_date

View File

@@ -2,13 +2,16 @@ import os
import re
import glob
import json
from pathlib import Path
from typing import Tuple
try:
from app.core.memory.utils.config.definitions import PROJECT_ROOT
except Exception:
# Fallback: derive project root from this file location
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py
# 需要向上 5 级到达 api/ 目录
PROJECT_ROOT = str(Path(__file__).resolve().parents[4])
def _get_latest_prompt_log_path() -> str | None:
@@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict:
triplet_relations_count = 0
temporal_count = 0
# Patterns
# 正则表达式模式 - 匹配当前日志格式
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)")
pat_triplet_done = re.compile(
r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)"
pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=")
pat_triplet_completed = re.compile(
r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)"
)
pat_temporal_done = re.compile(
r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)"
pat_temporal_completed = re.compile(
r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)"
)
with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
for line in f:
# Chunk prompts count (each chunk triggers one statement-extraction prompt render)
# 文本块数量(每个块触发一次陈述提取提示)
if pat_chunk_render.search(line):
chunk_count += 1
continue
m1 = pat_triplet_start.search(line)
if m1:
# 陈述数量(每个 Triplet Started 代表一个陈述被处理)
if pat_triplet_started.search(line):
statements_count += 1
continue
# 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y
m_triplet = pat_triplet_completed.search(line)
if m_triplet:
try:
statements_count += int(m1.group(1))
triplet_relations_count += int(m_triplet.group(1))
triplet_entities_count += int(m_triplet.group(2))
except Exception:
pass
continue
m2 = pat_triplet_done.search(line)
if m2:
# 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X
m_temporal = pat_temporal_completed.search(line)
if m_temporal:
try:
triplet_relations_count += int(m2.group(1))
triplet_entities_count += int(m2.group(2))
except Exception:
pass
continue
m3 = pat_temporal_done.search(line)
if m3:
try:
temporal_count += int(m3.group(1))
temporal_count += int(m_temporal.group(1))
except Exception:
pass
continue
@@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict:
def get_recent_activity_stats() -> Tuple[dict, str]:
"""Get aggregated stats from all prompt logs in logs/.
"""Get stats from the latest prompt log file only.
Returns (stats_dict, message).
"""
all_logs = _get_all_prompt_logs()
# Fallback to recursive search if none found in logs/
if not all_logs:
# 获取最新的日志文件
latest_log = _get_latest_prompt_log_path()
# 如果没有找到,尝试递归搜索
if not latest_log:
all_logs = _get_any_logs_recursive()
if not all_logs:
if all_logs:
latest_log = all_logs[-1] # 取最新的
if not latest_log:
return (
{
"chunk_count": 0,
@@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]:
"未找到日志文件,请确认已运行过提取流程。",
)
agg = {
"chunk_count": 0,
"statements_count": 0,
"triplet_entities_count": 0,
"triplet_relations_count": 0,
"temporal_count": 0,
}
for path in all_logs:
s = parse_stats_from_log(path)
agg["chunk_count"] += s.get("chunk_count", 0)
agg["statements_count"] += s.get("statements_count", 0)
agg["triplet_entities_count"] += s.get("triplet_entities_count", 0)
agg["triplet_relations_count"] += s.get("triplet_relations_count", 0)
agg["temporal_count"] += s.get("temporal_count", 0)
# Attach a summary of files combined
agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}"
return agg, "成功汇总 logs 目录中所有提示日志。"
# 只解析最新的日志文件
stats = parse_stats_from_log(latest_log)
# 添加日志文件路径信息
stats["log_path"] = f"最新:{latest_log}"
return stats, "成功读取最近一次记忆活动统计。"
def _format_summary(stats: dict) -> str:

View File

@@ -1 +0,0 @@
"""Evaluation package with dataset-specific pipelines and a unified runner."""

View File

@@ -1,30 +0,0 @@
⏬数据集下载地址:
Locomo10.jsonhttps://github.com/snap-research/locomo/tree/main/data
LongMemEval_oracle.jsonhttps://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
上方数据集下载好后全部放入app/core/memory/data文件夹中
全流程基准测试运行:
locomo
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
LongMemEval
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
memsciqa
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
单独检索评估运行命令:
python -m app.core.memory.evaluation.locomo.locomo_test
python -m app.core.memory.evaluation.longmemeval.test_eval
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
需要先在项目中修改需要检测评估的group_id。
参数及解释:
● --dataset longmemeval - 指定数据集
● --sample-size 10 - 评估10个样本
● --start-index 0 - 从第0个样本开始
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
● --search-limit 8 - 检索限制8条
● --context-char-budget 4000 - 上下文字符预算4000
● --search-type hybrid - 使用混合检索
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
● --reset-group - 运行前清空组数据

View File

@@ -1,100 +0,0 @@
import math
import re
from typing import List, Dict
def _normalize(text: str) -> List[str]:
"""Lowercase, strip punctuation, and split into tokens."""
text = text.lower().strip()
# Python's re doesn't support \p classes; use a simple non-word filter
text = re.sub(r"[^\w\s]", " ", text)
tokens = [t for t in text.split() if t]
return tokens
def exact_match(pred: str, ref: str) -> float:
return float(_normalize(pred) == _normalize(ref))
def jaccard(pred: str, ref: str) -> float:
p = set(_normalize(pred))
r = set(_normalize(ref))
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
def f1_score(pred: str, ref: str) -> float:
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens:
return 0.0
# Clipped count
r_counts: Dict[str, int] = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts: Dict[str, int] = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
# Brevity penalty
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def percentile(values: List[float], p: float) -> float:
if not values:
return 0.0
vals = sorted(values)
k = (len(vals) - 1) * p
f = math.floor(k)
c = math.ceil(k)
if f == c:
return vals[int(k)]
return vals[f] + (k - f) * (vals[c] - vals[f])
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
if not latencies_ms:
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
p25 = percentile(latencies_ms, 0.25)
p50 = percentile(latencies_ms, 0.50)
p75 = percentile(latencies_ms, 0.75)
p95 = percentile(latencies_ms, 0.95)
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
def avg_context_tokens(contexts: List[str]) -> float:
if not contexts:
return 0.0
return sum(len(_normalize(c)) for c in contexts) / len(contexts)

View File

@@ -1,60 +0,0 @@
"""
Dialogue search queries for evaluation purposes.
This file contains Cypher queries for searching dialogues, entities, and chunks.
Placed in evaluation directory to avoid circular imports with src modules.
"""
# Entity search queries
SEARCH_ENTITIES_BY_NAME = """
MATCH (e:Entity)
WHERE e.name = $name
RETURN e
"""
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
MATCH (e:Entity)
WHERE e.name CONTAINS $name
RETURN e
"""
# Chunk search queries
SEARCH_CHUNKS_BY_CONTENT = """
MATCH (c:Chunk)
WHERE c.content CONTAINS $content
RETURN c
"""
# Dialogue search queries
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE d.dialog_id = $dialog_id
RETURN d
"""
SEARCH_DIALOGUES_BY_CONTENT = """
MATCH (d:Dialogue)
WHERE d.content CONTAINS $q
RETURN d
"""
DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($group_id IS NULL OR d.group_id = $group_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -1,341 +0,0 @@
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
from app.core.memory.utils.config.definitions import (
SELECTED_CHUNKER_STRATEGY,
SELECTED_EMBEDDING_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
# Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
async def ingest_contexts_via_full_pipeline(
contexts: List[str],
group_id: str,
chunker_strategy: str | None = None,
embedding_name: str | None = None,
save_chunk_output: bool = False,
save_chunk_output_path: str | None = None,
) -> bool:
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
This function mirrors the steps in main(), but starts from raw text contexts.
Args:
contexts: List of dialogue texts, each containing lines like "role: message".
group_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
Returns:
True if data saved successfully, False otherwise.
"""
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
# Initialize llm client with graceful fallback
llm_client = None
llm_available = True
try:
from app.core.memory.utils.config import definitions as config_defs
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
except Exception as e:
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
llm_available = False
# Step A: Build DialogData list from contexts with robust parsing
chunker = DialogueChunker(chunker_strategy)
dialog_data_list: List[DialogData] = []
for idx, ctx in enumerate(contexts):
messages: List[ConversationMessage] = []
# Improved parsing: capture multi-line message blocks, normalize roles
pattern = r"^\s*(用户|AI|assistant|user)\s*[:]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[:]|\Z)"
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
if matches:
for m in matches:
raw_role = m.group(1).strip()
content = m.group(2).strip()
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=content))
else:
# Fallback: line-by-line parsing
for raw in ctx.split("\n"):
line = raw.strip()
if not line:
continue
m = re.match(r'^\s*([^:]+)\s*[:]\s*(.+)$', line)
if m:
role = m.group(1).strip()
msg = m.group(2).strip()
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=msg))
else:
# Final fallback: treat as user message
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
messages.append(ConversationMessage(role=default_role, msg=line))
context_model = ConversationContext(msgs=messages)
dialog = DialogData(
context=context_model,
ref_id=f"pipeline_item_{idx}",
group_id=group_id,
user_id="default_user",
apply_id="default_application",
)
# Generate chunks
dialog.chunks = await chunker.process_dialogue(dialog)
dialog_data_list.append(dialog)
if not dialog_data_list:
print("No dialogs to process for ingestion.")
return False
# Optionally save chunking outputs for debugging
if save_chunk_output:
try:
def _serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
from app.core.config import settings
settings.ensure_memory_output_dir()
default_path = settings.get_memory_output_path("chunker_test_output.txt")
out_path = save_chunk_output_path or default_path
combined_output = [dd.model_dump() for dd in dialog_data_list]
with open(out_path, "w", encoding="utf-8") as f:
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
print(f"Saved chunking results to: {out_path}")
except Exception as e:
print(f"Failed to save chunking results: {e}")
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
if not llm_available:
print("[Ingestion] Skipping extraction pipeline (no LLM).")
return False
# 初始化 embedder 客户端
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
try:
with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
except Exception as e:
print(f"[Ingestion] Failed to initialize embedder client: {e}")
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
return False
connector = Neo4jConnector()
# 初始化并运行 ExtractionOrchestrator
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=connector,
config=config,
)
# 创建一个包装的 orchestrator 来修复时间提取器的输出
# 保存原始的 _assign_extracted_data 方法
original_assign = orchestrator._assign_extracted_data
def clean_temporal_value(value):
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
if value is None:
return None
if isinstance(value, str):
# 处理字符串形式的 'null', 'None', 空字符串等
if value.lower() in ('null', 'none', '') or value.strip() == '':
return None
return value
async def patched_assign_extracted_data(*args, **kwargs):
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
result = await original_assign(*args, **kwargs)
# 清理返回的 dialog_data_list 中的 temporal_validity
for dialog in result:
if hasattr(dialog, 'chunks') and dialog.chunks:
for chunk in dialog.chunks:
if hasattr(chunk, 'statements') and chunk.statements:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
tv = statement.temporal_validity
# 清理 valid_at 和 invalid_at
if hasattr(tv, 'valid_at'):
tv.valid_at = clean_temporal_value(tv.valid_at)
if hasattr(tv, 'invalid_at'):
tv.invalid_at = clean_temporal_value(tv.invalid_at)
return result
# 替换方法
orchestrator._assign_extracted_data = patched_assign_extracted_data
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
original_create = orchestrator._create_nodes_and_edges
async def patched_create_nodes_and_edges(dialog_data_list_arg):
"""包装方法:在创建节点前再次清理 temporal_validity"""
# 最后一次清理,确保万无一失
for dialog in dialog_data_list_arg:
if hasattr(dialog, 'chunks') and dialog.chunks:
for chunk in dialog.chunks:
if hasattr(chunk, 'statements') and chunk.statements:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
tv = statement.temporal_validity
if hasattr(tv, 'valid_at'):
tv.valid_at = clean_temporal_value(tv.valid_at)
if hasattr(tv, 'invalid_at'):
tv.invalid_at = clean_temporal_value(tv.invalid_at)
return await original_create(dialog_data_list_arg)
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
# 运行完整的提取流水线
# orchestrator.run 返回 7 个元素的元组
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
) = result
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
# Step G: 生成记忆摘要
print("[Ingestion] Generating memory summaries...")
try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list,
llm_client=llm_client,
embedder_client=embedder_client
)
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
except Exception as e:
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
summaries = []
# Step H: Save to Neo4j
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
entity_edges=entity_entity_edges,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
connector=connector
)
# Save memory summaries separately
if summaries:
try:
await add_memory_summary_nodes(summaries, connector)
await add_memory_summary_statement_edges(summaries, connector)
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
except Exception as e:
print(f"Warning: Failed to save summary nodes: {e}")
await connector.close()
if success:
print("Successfully saved extracted data to Neo4j!")
else:
print("Failed to save data to Neo4j")
return success
except Exception as e:
print(f"Failed to save data to Neo4j: {e}")
return False
async def handle_context_processing(args):
"""Handle context-based processing from command line arguments."""
contexts = []
if args.contexts:
contexts.extend(args.contexts)
if args.context_file:
try:
with open(args.context_file, 'r', encoding='utf-8') as f:
contexts.extend(line.strip() for line in f if line.strip())
except Exception as e:
print(f"Error reading context file: {e}")
return False
if not contexts:
print("No contexts provided for processing.")
return False
return await main_from_contexts(contexts, args.context_group_id)
async def main_from_contexts(contexts: List[str], group_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline(
contexts=contexts,
group_id=group_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True
)
if success:
print("Successfully processed and saved contexts to Neo4j!")
else:
print("Failed to process contexts.")
return success

View File

@@ -1,575 +0,0 @@
"""
LoCoMo Benchmark Script
This module provides the main entry point for running LoCoMo benchmark evaluations.
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
in a clean, maintainable way.
Usage:
python locomo_benchmark.py --sample_size 20 --search_type hybrid
"""
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
try:
from dotenv import load_dotenv
except ImportError:
def load_dotenv():
pass
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
f1_score,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.locomo.locomo_metrics import (
get_category_name,
locomo_f1_score,
locomo_multi_f1,
)
from app.core.memory.evaluation.locomo.locomo_utils import (
extract_conversations,
ingest_conversations_if_needed,
load_locomo_data,
resolve_temporal_references,
retrieve_relevant_information,
select_and_format_information,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark(
sample_size: int = 20,
group_id: Optional[str] = None,
search_type: str = "hybrid",
search_limit: int = 12,
context_char_budget: int = 8000,
reset_group: bool = False,
skip_ingest: bool = False,
output_dir: Optional[str] = None
) -> Dict[str, Any]:
"""
Run LoCoMo benchmark evaluation.
This function orchestrates the complete evaluation pipeline:
1. Load LoCoMo dataset (only QA pairs from first conversation)
2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True)
3. For each question:
- Retrieve relevant information
- Generate answer using LLM
- Calculate metrics
4. Aggregate results and save to file
Note: By default, only the first conversation is ingested into the database,
and only QA pairs from that conversation are evaluated. This ensures that
all questions have corresponding memory in the database for retrieval.
Args:
sample_size: Number of QA pairs to evaluate (from first conversation)
group_id: Database group ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context
reset_group: Whether to clear and re-ingest data (not implemented)
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
output_dir: Directory to save results (uses default if None)
Returns:
Dictionary with evaluation results including metrics, timing, and samples
"""
# Use default group_id if not provided
group_id = group_id or SELECTED_GROUP_ID
# Determine data path
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
# Fallback to current directory
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
print(f"\n{'='*60}")
print("🚀 Starting LoCoMo Benchmark Evaluation")
print(f"{'='*60}")
print("📊 Configuration:")
print(f" Sample size: {sample_size}")
print(f" Group ID: {group_id}")
print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars")
print(f" Data path: {data_path}")
print(f"{'='*60}\n")
# Step 1: Load LoCoMo data
print("📂 Loading LoCoMo dataset...")
try:
# Only load QA pairs from the first conversation (index 0)
# since we only ingest the first conversation into the database
qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
except Exception as e:
print(f"❌ Failed to load data: {e}")
return {
"error": f"Data loading failed: {e}",
"timestamp": datetime.now().isoformat()
}
# Step 2: Extract conversations and ingest if needed
if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" Group ID: {group_id}\n")
else:
print("💾 Checking database ingestion...")
try:
conversations = extract_conversations(data_path, max_dialogues=1)
print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into group '{group_id}'...")
success = await ingest_conversations_if_needed(
conversations=conversations,
group_id=group_id,
reset=reset_group
)
if success:
print("✅ Ingestion completed successfully\n")
else:
print("⚠️ Ingestion may have failed, continuing anyway\n")
except Exception as e:
print(f"❌ Ingestion failed: {e}")
print("⚠️ Continuing with evaluation (database may be empty)\n")
# Step 3: Initialize clients
print("🔧 Initializing clients...")
connector = Neo4jConnector()
# Initialize LLM client with database context
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Initialize embedder
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
print("✅ Clients initialized\n")
# Step 4: Process questions
print(f"🔍 Processing {len(qa_items)} questions...")
print(f"{'='*60}\n")
# Tracking variables
latencies_search: List[float] = []
latencies_llm: List[float] = []
context_counts: List[int] = []
context_chars: List[int] = []
context_tokens: List[int] = []
# Metric lists
f1_scores: List[float] = []
bleu1_scores: List[float] = []
jaccard_scores: List[float] = []
locomo_f1_scores: List[float] = []
# Per-category tracking
category_counts: Dict[str, int] = {}
category_f1: Dict[str, List[float]] = {}
category_bleu1: Dict[str, List[float]] = {}
category_jaccard: Dict[str, List[float]] = {}
category_locomo_f1: Dict[str, List[float]] = {}
# Detailed samples
samples: List[Dict[str, Any]] = []
# Fixed anchor date for temporal resolution
anchor_date = datetime(2023, 5, 8)
try:
for idx, item in enumerate(qa_items, 1):
question = item.get("question", "")
ground_truth = item.get("answer", "")
category = get_category_name(item)
# Ensure ground truth is a string
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
print(f"[{idx}/{len(qa_items)}] Category: {category}")
print(f"❓ Question: {question}")
print(f"✅ Ground Truth: {ground_truth_str}")
# Step 4a: Retrieve relevant information
t_search_start = time.time()
try:
retrieved_info = await retrieve_relevant_information(
question=question,
group_id=group_id,
search_type=search_type,
search_limit=search_limit,
connector=connector,
embedder=embedder
)
t_search_end = time.time()
search_latency = (t_search_end - t_search_start) * 1000
latencies_search.append(search_latency)
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
except Exception as e:
print(f"❌ Retrieval failed: {e}")
retrieved_info = []
search_latency = 0.0
latencies_search.append(search_latency)
# Step 4b: Select and format context
context_text = select_and_format_information(
retrieved_info=retrieved_info,
question=question,
max_chars=context_char_budget
)
# Resolve temporal references
context_text = resolve_temporal_references(context_text, anchor_date)
# Add reference date to context
if context_text:
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
else:
context_text = "No relevant context found."
# Track context statistics
context_counts.append(len(retrieved_info))
context_chars.append(len(context_text))
context_tokens.append(len(context_text.split()))
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
# Step 4c: Generate answer with LLM
messages = [
{
"role": "system",
"content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)
},
{
"role": "user",
"content": f"Question: {question}\n\nContext:\n{context_text}"
}
]
t_llm_start = time.time()
try:
response = await llm_client.chat(messages=messages)
t_llm_end = time.time()
llm_latency = (t_llm_end - t_llm_start) * 1000
latencies_llm.append(llm_latency)
# Extract prediction from response
if hasattr(response, 'content'):
prediction = response.content.strip()
elif isinstance(response, dict):
prediction = response["choices"][0]["message"]["content"].strip()
else:
prediction = "Unknown"
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
except Exception as e:
print(f"❌ LLM failed: {e}")
prediction = "Unknown"
llm_latency = 0.0
latencies_llm.append(llm_latency)
# Step 4d: Calculate metrics
f1_val = f1_score(prediction, ground_truth_str)
bleu1_val = bleu1(prediction, ground_truth_str)
jaccard_val = jaccard(prediction, ground_truth_str)
# LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
if item.get("category") == 1:
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
else:
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
# Accumulate metrics
f1_scores.append(f1_val)
bleu1_scores.append(bleu1_val)
jaccard_scores.append(jaccard_val)
locomo_f1_scores.append(locomo_f1_val)
# Track by category
category_counts[category] = category_counts.get(category, 0) + 1
category_f1.setdefault(category, []).append(f1_val)
category_bleu1.setdefault(category, []).append(bleu1_val)
category_jaccard.setdefault(category, []).append(jaccard_val)
category_locomo_f1.setdefault(category, []).append(locomo_f1_val)
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
print()
# Save sample details
samples.append({
"question": question,
"ground_truth": ground_truth_str,
"prediction": prediction,
"category": category,
"metrics": {
"f1": f1_val,
"bleu1": bleu1_val,
"jaccard": jaccard_val,
"locomo_f1": locomo_f1_val
},
"retrieval": {
"num_docs": len(retrieved_info),
"context_length": len(context_text)
},
"timing": {
"search_ms": search_latency,
"llm_ms": llm_latency
}
})
finally:
# Close connector
await connector.close()
# Step 5: Aggregate results
print(f"\n{'='*60}")
print("📊 Aggregating Results")
print(f"{'='*60}\n")
# Overall metrics
overall_metrics = {
"f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0,
"bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0,
"jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0,
"locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0
}
# Per-category metrics
by_category: Dict[str, Dict[str, Any]] = {}
for cat in category_counts:
f1_list = category_f1.get(cat, [])
b1_list = category_bleu1.get(cat, [])
j_list = category_jaccard.get(cat, [])
lf_list = category_locomo_f1.get(cat, [])
by_category[cat] = {
"count": category_counts[cat],
"f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0,
"bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0,
"jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0,
"locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0
}
# Latency statistics
latency = {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm)
}
# Context statistics
context_stats = {
"avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0,
"avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0,
"avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0
}
# Build result dictionary
result = {
"dataset": "locomo",
"sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(),
"params": {
"group_id": group_id,
"search_type": search_type,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_id": SELECTED_LLM_ID,
"embedding_id": SELECTED_EMBEDDING_ID
},
"overall_metrics": overall_metrics,
"by_category": by_category,
"latency": latency,
"context_stats": context_stats,
"samples": samples
}
# Step 6: Save results
if output_dir is None:
output_dir = os.path.join(
os.path.dirname(__file__),
"results"
)
os.makedirs(output_dir, exist_ok=True)
# Generate timestamped filename
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")
try:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ Results saved to: {output_path}\n")
except Exception as e:
print(f"❌ Failed to save results: {e}")
print("📊 Printing results to console instead:\n")
print(json.dumps(result, ensure_ascii=False, indent=2))
return result
def main():
"""
Parse command-line arguments and run benchmark.
This function provides a CLI interface for running LoCoMo benchmarks
with configurable parameters.
"""
parser = argparse.ArgumentParser(
description="Run LoCoMo benchmark evaluation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sample_size",
type=int,
default=20,
help="Number of QA pairs to evaluate"
)
parser.add_argument(
"--group_id",
type=str,
default=None,
help="Database group ID for retrieval (uses default if not specified)"
)
parser.add_argument(
"--search_type",
type=str,
default="hybrid",
choices=["keyword", "embedding", "hybrid"],
help="Search strategy to use"
)
parser.add_argument(
"--search_limit",
type=int,
default=12,
help="Maximum number of documents to retrieve per query"
)
parser.add_argument(
"--context_char_budget",
type=int,
default=8000,
help="Maximum characters for context"
)
parser.add_argument(
"--reset_group",
action="store_true",
help="Clear and re-ingest data (not implemented)"
)
parser.add_argument(
"--skip_ingest",
action="store_true",
help="Skip data ingestion and use existing data in Neo4j"
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory to save results (uses default if not specified)"
)
args = parser.parse_args()
# Load environment variables
load_dotenv()
# Run benchmark
result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size,
group_id=args.group_id,
search_type=args.search_type,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
reset_group=args.reset_group,
skip_ingest=args.skip_ingest,
output_dir=args.output_dir
))
# Print summary
print(f"\n{'='*60}")
# Check if there was an error
if 'error' in result:
print("❌ Benchmark Failed!")
print(f"{'='*60}")
print(f"Error: {result['error']}")
return
print("🎉 Benchmark Complete!")
print(f"{'='*60}")
print("📊 Final Results:")
print(f" Sample size: {result.get('sample_size', 0)}")
print(f" F1: {result['overall_metrics']['f1']:.3f}")
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
if result.get('context_stats'):
print("\n📈 Context Statistics:")
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
if result.get('latency'):
print("\n⏱️ Latency Statistics:")
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
f"P50: {result['latency']['search']['p50']:.1f}ms, "
f"P95: {result['latency']['search']['p95']:.1f}ms")
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
f"P95: {result['latency']['llm']['p95']:.1f}ms")
if result.get('by_category'):
print("\n📂 Results by Category:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" Count: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
print(f" Jaccard: {metrics['jaccard']:.3f}")
print(f"\n{'='*60}\n")
if __name__ == "__main__":
main()

View File

@@ -1,225 +0,0 @@
"""
LoCoMo-specific metric calculations.
This module provides clean, simplified implementations of metrics used for
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
"""
import re
from typing import Dict, Any
def normalize_text(text: str) -> str:
"""
Normalize text for LoCoMo evaluation.
Normalization steps:
- Convert to lowercase
- Remove commas
- Remove stop words (a, an, the, and)
- Remove punctuation
- Normalize whitespace
Args:
text: Input text to normalize
Returns:
Normalized text string with consistent formatting
Examples:
>>> normalize_text("The cat, and the dog")
'cat dog'
>>> normalize_text("Hello, World!")
'hello world'
"""
# Ensure input is a string
text = str(text) if text is not None else ""
# Convert to lowercase
text = text.lower()
# Remove commas
text = re.sub(r"[\,]", " ", text)
# Remove stop words
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
# Remove punctuation (keep only word characters and whitespace)
text = re.sub(r"[^\w\s]", " ", text)
# Normalize whitespace (collapse multiple spaces to single space)
text = " ".join(text.split())
return text
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for single-answer questions.
Uses token-level precision and recall based on normalized text.
Treats tokens as sets (no duplicate counting).
Args:
prediction: Model's predicted answer
ground_truth: Correct answer
Returns:
F1 score between 0.0 and 1.0
Examples:
>>> locomo_f1_score("Paris", "Paris")
1.0
>>> locomo_f1_score("The cat", "cat")
1.0
>>> locomo_f1_score("dog", "cat")
0.0
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Normalize and tokenize
pred_tokens = normalize_text(pred_str).split()
truth_tokens = normalize_text(truth_str).split()
# Handle empty cases
if not pred_tokens or not truth_tokens:
return 0.0
# Convert to sets for comparison
pred_set = set(pred_tokens)
truth_set = set(truth_tokens)
# Calculate true positives (intersection)
true_positives = len(pred_set & truth_set)
# Calculate precision and recall
precision = true_positives / len(pred_set) if pred_set else 0.0
recall = true_positives / len(truth_set) if truth_set else 0.0
# Calculate F1 score
if precision + recall == 0:
return 0.0
f1 = 2 * precision * recall / (precision + recall)
return f1
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for multi-answer questions.
Handles comma-separated answers by:
1. Splitting both prediction and ground truth by commas
2. For each ground truth answer, finding the best matching prediction
3. Averaging the F1 scores across all ground truth answers
Args:
prediction: Model's predicted answer (may contain multiple comma-separated answers)
ground_truth: Correct answer (may contain multiple comma-separated answers)
Returns:
Average F1 score across all ground truth answers (0.0 to 1.0)
Examples:
>>> locomo_multi_f1("Paris, London", "Paris, London")
1.0
>>> locomo_multi_f1("Paris", "Paris, London")
0.5
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
0.5
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Split by commas and strip whitespace
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
# Handle empty cases
if not predictions or not ground_truths:
return 0.0
# For each ground truth, find the best matching prediction
f1_scores = []
for gt in ground_truths:
# Calculate F1 with each prediction and take the maximum
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
f1_scores.append(best_f1)
# Return average F1 across all ground truths
return sum(f1_scores) / len(f1_scores)
def get_category_name(item: Dict[str, Any]) -> str:
"""
Extract and normalize category name from QA item.
Handles both numeric categories (1-4) and string categories with various formats.
Supports multiple field names: "cat", "category", "type".
Category mapping:
- 1 or "multi-hop" -> "Multi-Hop"
- 2 or "temporal" -> "Temporal"
- 3 or "open domain" -> "Open Domain"
- 4 or "single-hop" -> "Single-Hop"
Args:
item: QA item dictionary containing category information
Returns:
Standardized category name or "unknown" if not found
Examples:
>>> get_category_name({"category": 1})
'Multi-Hop'
>>> get_category_name({"cat": "temporal"})
'Temporal'
>>> get_category_name({"type": "Single-Hop"})
'Single-Hop'
"""
# Numeric category mapping
CATEGORY_MAP = {
1: "Multi-Hop",
2: "Temporal",
3: "Open Domain",
4: "Single-Hop",
}
# String category aliases (case-insensitive)
TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
# Try "cat" field first (string category)
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return TYPE_ALIASES.get(lower, name)
# Try "category" field (can be int or string)
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP.get(cat_num, "unknown")
elif isinstance(cat_num, str) and cat_num.strip():
lower = cat_num.strip().lower()
return TYPE_ALIASES.get(lower, cat_num.strip())
# Try "type" field as fallback
cat_type = item.get("type")
if isinstance(cat_type, str) and cat_type.strip():
lower = cat_type.strip().lower()
return TYPE_ALIASES.get(lower, cat_type.strip())
return "unknown"

View File

@@ -1,810 +0,0 @@
# file name: check_neo4j_connection_fixed.py
import asyncio
import json
import math
import os
import re
import sys
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List
from dotenv import load_dotenv
# 1
# 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
src_dir = os.path.join(project_root, "src")
if src_dir not in sys.path:
sys.path.insert(0, src_dir)
load_dotenv()
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
def _loc_normalize(text: str) -> str:
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text)
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 尝试从 metrics.py 导入基础指标
try:
from common.metrics import bleu1, f1_score, jaccard
print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}")
# 回退到本地实现
def f1_score(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens:
return 0.0
r_counts = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def jaccard(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p = set(_loc_normalize(pred_str).split())
r = set(_loc_normalize(ref_str).split())
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
try:
# 添加 evaluation 目录路径
evaluation_dir = os.path.join(project_root, "evaluation")
if evaluation_dir not in sys.path:
sys.path.insert(0, evaluation_dir)
# 尝试从不同位置导入
try:
from locomo.qwen_search_eval import (
_resolve_relative_times,
loc_f1_score,
loc_multi_f1,
)
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError:
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e:
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
# 回退到本地实现 LoCoMo 特定函数
def _resolve_relative_times(text: str, anchor: datetime) -> str:
t = str(text) if text is not None else ""
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
p_tokens = _loc_normalize(prediction).split()
g_tokens = _loc_normalize(ground_truth).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
"""根据问题复杂度和进度动态调整检索参数"""
# 分析问题复杂度
word_count = len(question.split())
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
# 根据进度调整 - 后期问题可能需要更精确的检索
progress_factor = question_index / total_questions
base_limit = 12
if has_temporal and has_multi_hop:
base_limit = 20
elif word_count > 8:
base_limit = 16
# 随着测试进行,逐渐收紧检索范围
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
# 动态调整最大字符数
max_chars = 8000 + 4000 * (1 - progress_factor)
return {
"limit": adjusted_limit,
"max_chars": int(max_chars)
}
class EnhancedEvaluationMonitor:
def __init__(self, reset_interval=5, performance_threshold=0.6):
self.question_count = 0
self.reset_interval = reset_interval
self.performance_threshold = performance_threshold
self.consecutive_low_scores = 0
self.performance_history = []
self.recent_f1_scores = []
def should_reset_connections(self, current_f1=None):
"""基于计数和性能双重判断"""
# 定期重置
if self.question_count % self.reset_interval == 0:
return True
# 性能驱动的重置
if current_f1 is not None and current_f1 < self.performance_threshold:
self.consecutive_low_scores += 1
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
print("🚨 连续低分,触发紧急重置")
self.consecutive_low_scores = 0
return True
else:
self.consecutive_low_scores = 0
return False
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
"""记录性能指标,检测衰减"""
self.performance_history.append({
'index': question_index,
'metrics': metrics,
'context_length': context_length,
'retrieved_docs': retrieved_docs,
'timestamp': time.time()
})
# 记录最近的F1分数
self.recent_f1_scores.append(metrics['f1'])
if len(self.recent_f1_scores) > 5:
self.recent_f1_scores.pop(0)
def get_recent_performance(self):
"""获取近期平均性能"""
if not self.recent_f1_scores:
return 0.5
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
def get_performance_trend(self):
"""分析性能趋势"""
if len(self.performance_history) < 2:
return "stable"
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
return "stable"
recent_avg = sum(recent_metrics) / len(recent_metrics)
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
if recent_avg < earlier_avg * 0.8:
return "degrading"
elif recent_avg > earlier_avg * 1.1:
return "improving"
else:
return "stable"
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
"""基于问题复杂度和近期性能动态调整检索参数"""
# 基础参数
base_params = get_dynamic_search_params(question, question_index, total_questions)
# 性能自适应调整
if recent_performance < 0.5: # 近期表现差
# 增加检索范围,尝试获取更多上下文
base_params["limit"] = min(base_params["limit"] + 5, 25)
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
elif recent_performance > 0.8: # 近期表现好
# 收紧检索,提高精度
base_params["limit"] = max(base_params["limit"] - 2, 8)
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
# 中间阶段特殊处理
mid_sequence_factor = abs(question_index / total_questions - 0.5)
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用更精确的检索策略")
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
return base_params
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
"""考虑问题序列位置的智能选择"""
if not contexts:
return ""
# 在序列中间阶段使用更严格的筛选
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用严格上下文筛选")
# 提取问题关键词
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
# 只保留高度相关的上下文
filtered_contexts = []
for context in contexts:
context_lower = context.lower()
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
if any(char.isdigit() for char in context):
relevance_score += 2
# 提高阈值:只有得分>=3的上下文才保留
if relevance_score >= 3:
filtered_contexts.append(context)
else:
print(f" - 过滤低分上下文: 得分={relevance_score}")
contexts = filtered_contexts
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
# 使用原有的智能选择逻辑
return smart_context_selection(contexts, question, max_chars)
async def run_enhanced_evaluation():
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 修正导入路径:使用 app.core.memory.src 前缀
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 加载数据
# 获取项目根目录
current_file = os.path.abspath(__file__)
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
memory_dir = os.path.dirname(evaluation_dir) # memory目录
data_path = os.path.join(memory_dir, "data", "locomo10.json")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
qa_items = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
items = qa_items[:20] # 测试多少个问题
# 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 初始化连接器
connector = Neo4jConnector()
# 初始化结果字典
results = {
"questions": [],
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
"category_metrics": {},
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
"performance_trend": "stable",
"timestamp": datetime.now().isoformat(),
"enhanced_strategy": True
}
total_f1 = 0.0
total_bleu1 = 0.0
total_jaccard = 0.0
total_loc_f1 = 0.0
total_context_length = 0
total_retrieved_docs = 0
category_stats = {}
try:
for i, item in enumerate(items):
monitor.question_count += 1
# 获取近期性能用于重置判断
recent_performance = monitor.get_recent_performance()
# 增强的重置判断
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
if should_reset and i > 0:
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
await connector.close()
connector = Neo4jConnector() # 创建新连接
print("✅ 连接重置完成")
q = item.get("question", "")
ref = item.get("answer", "")
ref_str = str(ref) if ref is not None else ""
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
print(f"✅ 真实答案: {ref_str}")
# 分类别统计
category = "Unknown"
if item.get("category") == 1:
category = "Multi-Hop"
elif item.get("category") == 2:
category = "Temporal"
elif item.get("category") == 3:
category = "Open Domain"
elif item.get("category") == 4:
category = "Single-Hop"
# 增强的检索参数
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
search_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
# 使用项目标准的混合检索方法
t0 = time.time()
contexts_all = []
try:
# 使用统一的搜索服务
from app.core.memory.storage_services.search import run_hybrid_search
print("🔀 使用混合搜索服务...")
search_results = await run_hybrid_search(
query_text=q,
search_type="hybrid",
group_id="locomo_sk",
limit=20,
include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重
embedding_id=SELECTED_EMBEDDING_ID
)
# 处理搜索结果 - 新的搜索服务返回统一的结构
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
print(f"📊 有效上下文数量: {len(contexts_all)}")
except Exception as e:
print(f"❌ 检索失败: {e}")
contexts_all = []
t1 = time.time()
search_time = (t1 - t0) * 1000
# 增强的上下文选择
context_text = ""
if contexts_all:
# 使用增强的上下文选择
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览(不只是第一条)
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
# 🔍 调试:检查答案是否在上下文中
if ref_str and ref_str.strip():
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# LLM 回答
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
except Exception as e:
print(f"❌ LLM 生成失败: {e}")
pred = "Unknown"
t3 = time.time()
llm_time = (t3 - t2) * 1000
# 计算指标 - 使用导入的指标函数
f1_val = f1_score(pred, ref_str)
bleu1_val = bleu1(pred, ref_str)
jaccard_val = jaccard(pred, ref_str)
loc_f1_val = loc_f1_score(pred, ref_str)
print(f"🤖 LLM 回答: {pred}")
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
# 更新统计
total_f1 += f1_val
total_bleu1 += bleu1_val
total_jaccard += jaccard_val
total_loc_f1 += loc_f1_val
total_context_length += len(context_text)
total_retrieved_docs += len(contexts_all)
if category not in category_stats:
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
category_stats[category]["count"] += 1
category_stats[category]["f1_sum"] += f1_val
category_stats[category]["b1_sum"] += bleu1_val
category_stats[category]["j_sum"] += jaccard_val
category_stats[category]["loc_f1_sum"] += loc_f1_val
# 记录性能指标
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
# 保存结果
question_result = {
"question": q,
"ground_truth": ref_str,
"prediction": pred,
"category": category,
"metrics": metrics,
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": max_chars,
"recent_performance": recent_performance
},
"timing": {
"search_ms": search_time,
"llm_ms": llm_time
}
}
results["questions"].append(question_result)
print("="*60)
except Exception as e:
print(f"❌ 评估过程中发生错误: {e}")
# 即使出错,也返回已有的结果
import traceback
traceback.print_exc()
finally:
await connector.close()
# 计算总体指标
n = len(items)
if n > 0:
results["overall_metrics"] = {
"f1": total_f1 / n,
"b1": total_bleu1 / n,
"j": total_jaccard / n,
"loc_f1": total_loc_f1 / n
}
for category, stats in category_stats.items():
count = stats["count"]
results["category_metrics"][category] = {
"count": count,
"f1": stats["f1_sum"] / count,
"bleu1": stats["b1_sum"] / count,
"jaccard": stats["j_sum"] / count,
"loc_f1": stats["loc_f1_sum"] / count
}
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
# 分析性能趋势
results["performance_trend"] = monitor.get_performance_trend()
results["reset_interval"] = monitor.reset_interval
results["total_questions_processed"] = monitor.question_count
return results
if __name__ == "__main__":
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
print("📋 增强特性:")
print(" - 双重重置策略:定期重置 + 性能驱动重置")
print(" - 动态检索参数:基于近期性能自适应调整")
print(" - 中间阶段严格筛选:提高上下文质量要求")
print(" - 连续性能监控:实时检测性能衰减")
result = asyncio.run(run_enhanced_evaluation())
print("\n📊 最终评估结果:")
print("总体指标:")
print(f" F1: {result['overall_metrics']['f1']:.4f}")
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
print("\n分类别指标:")
for category, metrics in result['category_metrics'].items():
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
print("\n检索统计:")
stats = result['retrieval_stats']
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
print(f"\n性能趋势: {result['performance_trend']}")
print(f"重置间隔: 每{result['reset_interval']}个问题")
print(f"处理问题总数: {result['total_questions_processed']}")
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
# 保存结果到指定目录
# 使用代码文件所在目录的绝对路径
current_file_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_file_dir, "results")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n详细结果已保存到: {output_file}")

View File

@@ -1,626 +0,0 @@
"""
LoCoMo Utilities Module
This module provides helper functions for the LoCoMo benchmark evaluation:
- Data loading from JSON files
- Conversation extraction for ingestion
- Temporal reference resolution
- Context selection and formatting
- Retrieval wrapper functions
- Ingestion wrapper functions
"""
import os
import json
import re
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from app.core.memory.utils.definitions import PROJECT_ROOT
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
def load_locomo_data(
data_path: str,
sample_size: int,
conversation_index: int = 0
) -> List[Dict[str, Any]]:
"""
Load LoCoMo dataset from JSON file.
The LoCoMo dataset structure is a list of conversation objects, where each
object contains a "qa" list of question-answer pairs.
Args:
data_path: Path to locomo10.json file
sample_size: Number of QA pairs to load (limits total QA items returned)
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
Returns:
List of QA item dictionaries, each containing:
- question: str
- answer: str
- category: int (1-4)
- evidence: List[str]
Raises:
FileNotFoundError: If data_path does not exist
json.JSONDecodeError: If file is not valid JSON
IndexError: If conversation_index is out of range
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo data structure: list of objects, each with a "qa" list
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
# Only load QA pairs from the specified conversation
if conversation_index < len(raw):
entry = raw[conversation_index]
if isinstance(entry, dict) and "qa" in entry:
qa_items.extend(entry.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has {len(raw)} conversations."
)
else:
# Fallback: single object with qa list
if conversation_index == 0:
qa_items.extend(raw.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has only 1 conversation."
)
# Return only the requested sample size
return qa_items[:sample_size]
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
"""
Extract conversation texts from LoCoMo data for ingestion.
This function extracts the raw conversation dialogues from the LoCoMo dataset
so they can be ingested into the memory system. Each conversation is formatted
as a multi-line string with "role: message" format.
Args:
data_path: Path to locomo10.json file
max_dialogues: Maximum number of dialogues to extract (default: 1)
Returns:
List of conversation strings formatted for ingestion.
Each string contains multiple lines in format "role: message"
Example output:
[
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
]
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# Ensure we have a list of entries
entries = raw if isinstance(raw, list) else [raw]
contents: List[str] = []
for i, entry in enumerate(entries[:max_dialogues]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
if not isinstance(conv, dict):
continue
lines: List[str] = []
# Collect all session_* messages
for key, val in sorted(conv.items()):
if isinstance(val, list) and key.startswith("session_"):
for msg in val:
if not isinstance(msg, dict):
continue
role = msg.get("speaker") or "User"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
if lines:
contents.append("\n".join(lines))
return contents
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
"""
Resolve relative temporal references to absolute dates.
This function converts relative time expressions (like "today", "yesterday",
"3 days ago") into absolute ISO date strings based on an anchor date.
Supported patterns:
- today, yesterday, tomorrow
- X days ago, in X days
- last week, next week
Args:
text: Text containing temporal references
anchor_date: Reference date for resolution (datetime object)
Returns:
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
Example:
>>> anchor = datetime(2023, 5, 8)
>>> resolve_temporal_references("I saw him yesterday", anchor)
"I saw him 2023-05-07"
"""
# Ensure input is a string
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(
r"\btoday\b",
anchor_date.date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\byesterday\b",
(anchor_date - timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\btomorrow\b",
(anchor_date + timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
# X days ago
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date - timedelta(days=n)).date().isoformat()
# in X days
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date + timedelta(days=n)).date().isoformat()
t = re.sub(
r"\b(\d+)\s+days?\s+ago\b",
_ago_repl,
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\bin\s+(\d+)\s+days?\b",
_in_repl,
t,
flags=re.IGNORECASE
)
# last week / next week (approximate as 7 days)
t = re.sub(
r"\blast\s+week\b",
(anchor_date - timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\bnext\s+week\b",
(anchor_date + timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
return t
def select_and_format_information(
retrieved_info: List[str],
question: str,
max_chars: int = 8000
) -> str:
"""
Intelligently select and format most relevant retrieved information for LLM prompt.
This function scores each piece of retrieved information based on keyword matching
with the question, then selects the highest-scoring pieces up to the character limit.
Scoring criteria:
- Keyword matches (higher weight for multiple occurrences)
- Context length (moderate length preferred)
- Position (earlier contexts get bonus points)
Args:
retrieved_info: List of retrieved information strings (chunks, statements, entities)
question: Question being answered
max_chars: Maximum total characters to include in final prompt
Returns:
Formatted string combining the most relevant information for LLM prompt.
Contexts are separated by double newlines.
Example:
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
>>> question = "Where did Alice go?"
>>> select_and_format_information(contexts, question, max_chars=100)
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
"""
if not retrieved_info:
return ""
# Extract question keywords (filter out stop words and short words)
question_lower = question.lower()
stop_words = {
'what', 'when', 'where', 'who', 'why', 'how',
'did', 'do', 'does', 'is', 'are', 'was', 'were',
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {
word for word in question_words
if word not in stop_words and len(word) > 2
}
# Score each context
scored_contexts = []
for i, context in enumerate(retrieved_info):
context_lower = context.lower()
score = 0
# Keyword matching score
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# Multiple occurrences increase score
score += context_lower.count(word) * 2
# Length score (prefer moderate length)
context_len = len(context)
if 100 < context_len < 2000:
score += 5
elif context_len >= 2000:
score += 2
# Position bonus (earlier contexts often more relevant)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# Sort by score (descending)
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# Select contexts up to character limit
selected = []
total_chars = 0
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
else:
# Try to include high-scoring context by truncating
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# Find lines with keywords
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
truncated = '\n'.join(relevant_lines)
selected.append(truncated + "\n[Content truncated...]")
total_chars += len(truncated)
break
return "\n\n".join(selected)
async def retrieve_relevant_information(
question: str,
group_id: str,
search_type: str,
search_limit: int,
connector: Any,
embedder: Any
) -> List[str]:
"""
Retrieve relevant information from memory graph for a question.
This function searches the Neo4j memory graph (populated during ingestion) and
returns relevant chunks, statements, and entity information that might help
answer the question.
The function supports three search types:
- "keyword": Full-text search using Cypher queries
- "embedding": Vector similarity search using embeddings
- "hybrid": Combination of keyword and embedding search with reranking
Args:
question: Question to search for
group_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance
embedder: Embedder client instance
Returns:
List of text strings (chunks, statements, entity summaries) from memory graph.
Each string represents a piece of retrieved information.
Raises:
Exception: If search fails (caught and returns empty list)
"""
from app.repositories.neo4j.graph_search import (
search_graph,
search_graph_by_embedding
)
from app.core.memory.storage_services.search import run_hybrid_search
contexts_all: List[str] = []
try:
if search_type == "embedding":
# Embedding-based search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context from chunks
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add summaries
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities (limit to 3 to avoid noise)
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# Keyword-based search
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
limit=search_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
# Build context from dialogues
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add entity names
if entities:
entity_names = [
str(e.get("name", "")).strip()
for e in entities[:5]
if e.get("name")
]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# Hybrid search with fallback to embedding
try:
search_results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
)
# Handle flat structure (new API format)
if search_results and isinstance(search_results, dict):
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Check if we got results
if not (chunks or statements or entities or summaries):
# Try nested structure (backward compatibility)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
raise ValueError("Hybrid search returned empty results")
else:
raise ValueError("Hybrid search returned empty results")
except Exception as e:
# Fallback to embedding search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context (same for both hybrid and fallback)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
except Exception as e:
# Return empty list on error
contexts_all = []
return contexts_all
async def ingest_conversations_if_needed(
conversations: List[str],
group_id: str,
reset: bool = False
) -> bool:
"""
Wrapper for conversation ingestion using external extraction pipeline.
This function populates the Neo4j database with processed conversation data
(chunks, statements, entities) so that the retrieval system has memory to search.
The ingestion process:
1. Parses conversation text into dialogue messages
2. Chunks the dialogues into semantic units
3. Extracts statements and entities using LLM
4. Generates embeddings for all content
5. Stores everything in Neo4j graph database
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
group_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
True if successful, False otherwise
Note:
The external function uses "contexts" to mean "conversation texts".
This runs the full extraction pipeline: chunking → entity extraction →
statement extraction → embedding → Neo4j storage.
"""
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
group_id=group_id,
save_chunk_output=True
)
return success
except Exception as e:
print(f"[Ingestion] Failed to ingest conversations: {e}")
return False

View File

@@ -1,878 +0,0 @@
import argparse
import asyncio
import json
import os
import statistics
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
import re
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
def _loc_normalize(text: str) -> str:
import re
# 确保输入是字符串
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text) # 去掉逗号
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 追加相对时间归一化为绝对日期有限支持today/yesterday/tomorrow/X days ago/in X days/last week/next week
def _resolve_relative_times(text: str, anchor: datetime) -> str:
import re
# 确保输入是字符串
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
# X days ago / in X days
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
# last week / next week以7天近似
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
# 单答案 F1按词集合计算近似原始实现去除词干依赖
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
p_tokens = _loc_normalize(pred_str).split()
g_tokens = _loc_normalize(truth_str).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
# 多答案 F1prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
CATEGORY_MAP_NUM_TO_NAME = {
4: "Single-Hop",
1: "Multi-Hop",
3: "Open Domain",
2: "Temporal",
}
_TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
def get_category_label(item: Dict[str, Any]) -> str:
# 1) 直接用字符串 cat
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return _TYPE_ALIASES.get(lower, name)
# 2) 数字 category 转名称
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
# 3) 备用 type 字段
t = item.get("type")
if isinstance(t, str) and t.strip():
lower = t.strip().lower()
return _TYPE_ALIASES.get(lower, t.strip())
return "unknown"
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_search_params_by_category(category: str):
"""根据问题类别调整检索参数"""
params_map = {
"Multi-Hop": {"limit": 20, "max_chars": 15000},
"Temporal": {"limit": 16, "max_chars": 10000},
"Open Domain": {"limit": 24, "max_chars": 18000},
"Single-Hop": {"limit": 12, "max_chars": 8000},
}
return params_map.get(category, {"limit": 16, "max_chars": 12000})
async def run_locomo_eval(
sample_size: int = 1,
group_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0,
llm_max_tokens: int = 32,
search_type: str = "hybrid", # 保持默认值不变
output_path: str | None = None,
skip_ingest_if_exists: bool = True,
llm_timeout: float = 10.0,
llm_max_retries: int = 1
) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变
group_id = group_id or SELECTED_GROUP_ID
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
items: List[Dict[str, Any]] = qa_items[:sample_size]
# === 保持原来的数据摄入逻辑 ===
entries = raw if isinstance(raw, list) else [raw]
# 只摄入前1条对话保持原样
max_dialogues_to_ingest = 1
contents: List[str] = []
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest}")
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
sample_id = entry.get("sample_id", f"unknown_{i}")
print(f"🔍 处理对话 {i+1}: {sample_id}")
lines: List[str] = []
if isinstance(conv, dict):
# 收集所有 session_* 的消息
session_count = 0
for key, val in conv.items():
if isinstance(val, list) and key.startswith("session_"):
session_count += 1
for msg in val:
role = msg.get("speaker") or "用户"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
if not lines:
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
continue
contents.append("\n".join(lines))
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
# 选择要评测的QA对从所有对话中选取
indexed_items: List[tuple[int, Dict[str, Any]]] = []
if isinstance(raw, list):
for e_idx, entry in enumerate(raw):
for qa in entry.get("qa", []):
indexed_items.append((e_idx, qa))
else:
for qa in raw.get("qa", []):
indexed_items.append((0, qa))
# 这里使用sample_size来限制评测的QA数量
selected = indexed_items[:sample_size]
items: List[Dict[str, Any]] = [qa for _, qa in selected]
print(f"🎯 将评测 {len(items)} 个QA对数据库中只包含 {len(contents)} 个对话")
# === 修改结束 ===
connector = Neo4jConnector()
# 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
# 使用异步LLM客户端
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder用于直接调用
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# connector initialized above
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 上下文诊断收集
per_query_context_counts: List[int] = []
per_query_context_avg_tokens: List[float] = []
per_query_context_chars: List[int] = []
per_query_context_tokens_total: List[int] = []
# 详细样本调试信息
samples: List[Dict[str, Any]] = []
# 通用指标
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
# 参考 LoCoMo 评测的类别专用 F1multi-hop 使用多答案 F1
loc_f1s: List[float] = []
# Per-category aggregation
cat_counts: Dict[str, int] = {}
cat_f1s: Dict[str, List[float]] = {}
cat_b1s: Dict[str, List[float]] = {}
cat_jss: Dict[str, List[float]] = {}
cat_loc_f1s: Dict[str, List[float]] = {}
try:
for item in items:
q = item.get("question", "")
ref = item.get("answer", "")
# 确保答案是字符串
ref_str = str(ref) if ref is not None else ""
cat = get_category_label(item)
print(f"\n=== 处理问题: {q} ===")
# 根据类别调整检索参数
search_params = get_search_params_by_category(cat)
adjusted_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
# 改进的检索逻辑使用三路检索statements, dialogues, entities
t0 = time.time()
contexts_all: List[str] = []
search_results = None # 保存完整的检索结果
try:
if search_type == "embedding":
# 直接调用嵌入检索,包含三路数据
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# 直接调用关键词检索
search_results = await search_graph(
connector=connector,
q=q,
group_id=group_id,
limit=adjusted_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
# 构建上下文
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# 实体处理(关键词检索的实体可能没有分数)
if entities:
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# 🎯 关键修复:混合检索使用更严格的回退机制
print("🔀 使用混合检索(带回退机制)...")
try:
search_results = await run_hybrid_search(
query_text=q,
search_type=search_type,
group_id=group_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
)
# 🎯 关键修复:正确处理混合检索的扁平结构
# 新的API返回扁平结构直接从顶层获取结果
if search_results and isinstance(search_results, dict):
# 新API返回扁平结构直接从顶层获取
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# 检查是否有有效结果
if chunks or statements or entities or summaries:
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
else:
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
print(f"✅ 混合检索成功使用旧格式reranked结果: {len(chunks)} chunks, {len(statements)} 陈述")
else:
raise ValueError("混合检索返回空结果")
else:
raise ValueError("混合检索返回空结果")
except Exception as e:
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
# 🎯 统一处理:构建上下文(所有检索类型共用)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
# 关键修复:过滤掉包含当前问题答案的上下文
filtered_contexts = []
for context in contexts_all:
content = str(context)
# 排除包含当前问题标准答案的上下文
if ref_str and ref_str.strip() and ref_str.strip() in content:
print("🚫 过滤掉包含标准答案的上下文")
continue
filtered_contexts.append(context)
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
contexts_all = filtered_contexts
# 输出完整的检索结果信息
print("🔍 检索结果详情:")
if search_results:
output_data = {
"statements": [
{
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
"score": s.get("score", 0.0)
}
for s in (statements[:2] if 'statements' in locals() else [])
],
"dialogues": [
{
"uuid": d.get("uuid", ""),
"group_id": d.get("group_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0)
}
for d in (dialogs[:2] if 'dialogs' in locals() else [])
],
"entities": [
{
"name": e.get("name", ""),
"entity_type": e.get("entity_type", ""),
"score": e.get("score", 0.0)
}
for e in (entities[:2] if 'entities' in locals() else [])
]
}
print(json.dumps(output_data, ensure_ascii=False, indent=2))
else:
print(" 无检索结果")
except Exception as e:
print(f"{search_type}检索失败: {e}")
contexts_all = []
search_results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 使用智能上下文选择
context_text = ""
if contexts_all:
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# 记录上下文诊断信息
per_query_context_counts.append(len(contexts_all))
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
per_query_context_chars.append(len(context_text))
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
# LLM 提示词
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
# 使用异步调用
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
# 计算指标(确保使用字符串)
f1_val = common_f1(str(pred), ref_str)
b1_val = bleu1(str(pred), ref_str)
j_val = jaccard(str(pred), ref_str)
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
# Accumulate by category
cat_counts[cat] = cat_counts.get(cat, 0) + 1
cat_f1s.setdefault(cat, []).append(f1_val)
cat_b1s.setdefault(cat, []).append(b1_val)
cat_jss.setdefault(cat, []).append(j_val)
# LoCoMo 专用 F1multi-hop(1) 使用多答案 F1其它(2/3/4)使用单答案 F1
if item.get("category") in [2, 3, 4]:
loc_val = loc_f1_score(str(pred), ref_str)
elif item.get("category") in [1]:
loc_val = loc_multi_f1(str(pred), ref_str)
else:
loc_val = loc_f1_score(str(pred), ref_str)
loc_f1s.append(loc_val)
cat_loc_f1s.setdefault(cat, []).append(loc_val)
# 保存完整的检索结果信息
samples.append({
"question": q,
"answer": ref_str,
"category": cat,
"prediction": pred,
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val,
"loc_f1": loc_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": adjusted_limit,
"max_chars": max_chars
},
"timing": {
"search_ms": (t1 - t0) * 1000,
"llm_ms": (t3 - t2) * 1000
}
})
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {ref_str}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
# Compute per-category averages and dispersion (std, iqr)
def _percentile(sorted_vals: List[float], p: float) -> float:
if not sorted_vals:
return 0.0
if len(sorted_vals) == 1:
return sorted_vals[0]
k = (len(sorted_vals) - 1) * p
f = int(k)
c = f + 1 if f + 1 < len(sorted_vals) else f
if f == c:
return sorted_vals[f]
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
by_category: Dict[str, Dict[str, float | int]] = {}
for c in cat_counts:
f_list = cat_f1s.get(c, [])
b_list = cat_b1s.get(c, [])
j_list = cat_jss.get(c, [])
lf_list = cat_loc_f1s.get(c, [])
j_sorted = sorted(j_list)
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
j_q75 = _percentile(j_sorted, 0.75)
j_q25 = _percentile(j_sorted, 0.25)
by_category[c] = {
"count": cat_counts[c],
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
"j_std": j_std,
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
# 参考 LoCoMo 评测的类别专用 F1
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
}
# 累加命中cum accuracy by category与 evaluation_stats.py 输出形式相仿
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
result = {
"dataset": "locomo",
"items": len(items),
"metrics": {
"f1": sum(f1s) / max(len(f1s), 1),
"b1": sum(b1s) / max(len(b1s), 1),
"j": sum(jss) / max(len(jss), 1),
# LoCoMo 类别专用 F1 的总体
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
},
"by_category": by_category,
"category_counts": cat_counts,
"cum_accuracy_by_category": cum_accuracy_by_category,
"context": {
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"group_id": group_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
"llm_id": SELECTED_LLM_ID,
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
"skip_ingest_if_exists": skip_ingest_if_exists,
"llm_timeout": llm_timeout,
"llm_max_retries": llm_max_retries,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens
},
"timestamp": datetime.now().isoformat()
}
if output_path:
try:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ 结果已保存到: {output_path}")
except Exception as e:
print(f"❌ 保存结果失败: {e}")
return result
finally:
await connector.close()
def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
args = parser.parse_args()
load_dotenv()
result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size,
group_id=args.group_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
output_path=args.output_path,
skip_ingest_if_exists=args.skip_ingest_if_exists,
llm_timeout=args.llm_timeout,
llm_max_retries=args.llm_max_retries
))
print("\n" + "="*50)
print("📊 最终评测结果:")
print(f" 样本数量: {result['items']}")
print(f" F1: {result['metrics']['f1']:.3f}")
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
print(f" Jaccard: {result['metrics']['j']:.3f}")
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
if result['by_category']:
print("\n📈 按类别细分:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" 样本数: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
print(f" Jaccard: {metrics['j']:.3f}{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,324 +0,0 @@
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
if not contexts:
return ""
import re
# 提取问题关键词(移除停用词)
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
# 评分
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
# 选择直到达到字符限制,必要时截断包含关键词的段落
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
"""Compose a text context from `dialog` list in msc_self_instruct item."""
parts: List[str] = []
for turn in dialog_obj.get("dialog", []):
speaker = turn.get("speaker", "")
text = turn.get("text", "")
if text:
parts.append(f"{speaker}: {text}")
return "\n".join(parts)
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Combine dialogues from embedding and keyword searches (embedding first)."""
if results is None:
return []
emb = []
kw = []
if isinstance(results.get("embedding_search"), dict):
emb = results.get("embedding_search", {}).get("dialogues", []) or []
elif isinstance(results.get("dialogues"), list):
emb = results.get("dialogues", []) or []
if isinstance(results.get("keyword_search"), dict):
kw = results.get("keyword_search", {}).get("dialogues", []) or []
seen = set()
merged: List[Dict[str, Any]] = []
for d in emb:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
for d in kw:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID
# Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, group_id)
# LLM client (使用异步调用)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Evaluate each item
connector = Neo4jConnector()
latencies_llm: List[float] = []
latencies_search: List[float] = []
contexts_used: List[str] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
try:
for item in items:
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 检索:对齐 locomo 的三路检索dialogues/statements/entities
t0 = time.time()
try:
results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
memory_config=memory_config,
)
except Exception:
results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
contexts_all: List[str] = []
if results:
if search_type == "hybrid":
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
emb_dialogs = emb.get("dialogues", [])
emb_statements = emb.get("statements", [])
emb_entities = emb.get("entities", [])
kw_dialogs = kw.get("dialogues", [])
kw_statements = kw.get("statements", [])
kw_entities = kw.get("entities", [])
all_dialogs = emb_dialogs + kw_dialogs
all_statements = emb_statements + kw_statements
all_entities = emb_entities + kw_entities
# 简单去重与限制
seen_texts = set()
for d in all_dialogs:
text = str(d.get("content", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
for s in all_statements:
text = str(s.get("statement", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
# 实体摘要最多3个
names = []
merged_entities = all_entities[:]
for e in merged_entities:
name = str(e.get("name", "")).strip()
if name and name not in names:
names.append(name)
if len(names) >= 3:
break
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
else:
dialogs = results.get("dialogues", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
for d in dialogs:
text = str(d.get("content", "")).strip()
if text:
contexts_all.append(text)
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
# 智能选择并截断到预算
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text[:200])
# Call LLM (使用异步调用)
messages = [
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import (
bleu1,
f1_score,
jaccard,
)
f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference)))
# Aggregate metrics
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"accuracy": acc,
# Placeholders for extensibility
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"avg_context_tokens": ctx_avg_tokens,
}
return result
finally:
await connector.close()
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
args = parser.parse_args()
result = asyncio.run(
run_memsciqa_eval(
sample_size=args.sample_size,
group_id=args.group_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
)
)
print(json.dumps(result, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()

View File

@@ -1,576 +0,0 @@
import argparse
import asyncio
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 路径与模块导入保持与现有评估脚本一致
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
for _p in (_SRC_DIR, _PROJECT_ROOT):
if _p not in sys.path:
sys.path.insert(0, _p)
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
try:
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
except Exception:
# 兜底:简单实现(必要时)
def f1_score(pred: str, ref: str) -> float:
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
tp = len(set(ps) & set(rs))
if tp == 0:
return 0.0
precision = tp / len(ps)
recall = tp / len(rs)
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
overlap = len([w for w in ps if w in rs])
return overlap / max(len(ps), 1)
def jaccard(pred: str, ref: str) -> float:
ps = set(pred.lower().split())
rs = set(ref.lower().split())
union = len(ps | rs)
if union == 0:
return 0.0
return len(ps & rs) / union
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
"""
if not contexts:
return ""
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3"""
ql = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
}
words = re.findall(r"\b[\w-]+\b", ql)
kws = [w for w in words if w not in stop_words and len(w) >= 3]
# 去重保序
seen = set()
uniq = []
for w in kws:
if w not in seen:
uniq.append(w)
seen.add(w)
if len(uniq) >= max_keywords:
break
return uniq
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
"""对上下文进行简单相关性打分,仅用于控制台可视化。
评分: score = match_count*200 + min(len(text), 100000)/100
"""
results = []
for ctx in contexts:
tl = (ctx or "").lower()
match_count = sum(1 for k in keywords if k in tl)
length = len(ctx)
score = match_count * 200 + min(length, 100000) / 100.0
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
return results[:max(top_n, 0)]
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
if not os.path.exists(data_path):
raise FileNotFoundError(f"未找到数据集: {data_path}")
items: List[Dict[str, Any]] = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
items.append(json.loads(line))
except Exception:
# 跳过坏行但不中断
continue
return items
async def run_memsciqa_test(
sample_size: int = 3,
group_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
llm_max_tokens: int = 64,
search_type: str = "embedding",
data_path: str | None = None,
start_index: int = 0,
verbose: bool = True,
) -> Dict[str, Any]:
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
- 支持从指定索引开始与评估全部样本sample_size<=0
- 支持在摄入前重置组(清空图)与跳过摄入
- 支持 keyword / embedding / hybrid 三种检索
"""
# 默认使用指定的 memsci 组 ID
group_id = group_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底)
if not data_path:
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
if os.path.exists(proj_path):
data_path = proj_path
elif os.path.exists(cwd_path):
data_path = cwd_path
else:
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl请确保其存在于项目根目录或当前工作目录的 data 目录下。")
# 加载数据
all_items = load_dataset_memsciqa(data_path)
if sample_size is None or sample_size <= 0:
items = all_items[start_index:]
else:
items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test
connector = Neo4jConnector()
embedder = None
if search_type in ("embedding", "hybrid"):
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 评估循环
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 存储完整上下文文本用于统计
contexts_used: List[str] = []
per_query_context_chars: List[int] = []
per_query_context_counts: List[int] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
samples: List[Dict[str, Any]] = []
total_items = len(items)
for idx, item in enumerate(items):
if verbose:
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 三路检索chunks/statements/entities/summaries对齐 qwen_search_eval.py
t0 = time.time()
results = None
try:
if search_type in ("embedding", "hybrid"):
# 使用嵌入检索(与 qwen_search_eval 对齐)
results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
elif search_type == "keyword":
# 关键词检索(直接调用 graph_search
results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
except Exception:
results = None
t1 = time.time()
search_ms = (t1 - t0) * 1000
latencies_search.append(search_ms)
# 构建上下文:包含 chunks、陈述、摘要和实体对齐 qwen_search_eval.py
contexts_all: List[str] = []
retrieved_counts: Dict[str, int] = {}
if results:
chunks = results.get("chunks", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
summaries = results.get("summaries", [])
retrieved_counts = {
"chunks": len(chunks),
"statements": len(statements),
"entities": len(entities),
"summaries": len(summaries),
}
# 优先使用 chunks
for c in chunks:
text = str(c.get("content", "")).strip()
if text:
contexts_all.append(text)
# 然后是 statements
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
# 然后是 summaries
for sm in summaries:
text = str(sm.get("summary", "")).strip()
if text:
contexts_all.append(text)
# 实体摘要最多加入前3个高分实体对齐 qwen_search_eval.py
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
if verbose:
if retrieved_counts:
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
print(f"📊 有效上下文数量: {len(contexts_all)}")
q_keywords = extract_question_keywords(question, max_keywords=8)
if q_keywords:
print(f"🔍 问题关键词: {set(q_keywords)}")
if contexts_all:
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
if analysis:
print("📊 上下文相关性分析:")
for a in analysis:
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
# 打印检索到的上下文预览,便于定位为何为 Unknown
print("🔎 上下文预览最多前10条每条截断展示:")
for i, ctx in enumerate(contexts_all[:10]):
preview = str(ctx).replace("\n", " ")
if len(preview) > 300:
preview = preview[:300] + "..."
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
# 标注参考答案是否出现在任一上下文中
ref_lower = (str(reference) or "").lower()
if ref_lower:
hits = []
for i, ctx in enumerate(contexts_all):
if ref_lower in str(ctx).lower():
hits.append(i+1)
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text)
per_query_context_chars.append(len(context_text))
per_query_context_counts.append(len(contexts_all))
if verbose:
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
# 展示拼接后的上下文片段,便于核查是否包含答案
concat_preview = context_text.replace("\n", " ")
if len(concat_preview) > 600:
concat_preview = concat_preview[:600] + "..."
print(f"🧵 拼接上下文预览: {concat_preview}")
messages = [
{
"role": "system",
"content": (
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
"3) Keep your answer brief and to the point;\n"
"4) Do not add explanations or additional text beyond the answer."
),
},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 更健壮的响应解析处理不同的LLM响应格式
if hasattr(resp, 'content'):
pred = resp.content.strip()
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
pred = resp["choices"][0]["message"]["content"].strip()
elif isinstance(resp, dict) and "content" in resp:
pred = resp["content"].strip()
elif isinstance(resp, str):
pred = resp.strip()
else:
pred = "Unknown"
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
if pred.lower() in ["unknown", ""]:
# 如果参考答案在上下文中存在但LLM返回Unknown可能是提示词问题
ref_lower = (str(reference) or "").lower()
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown检查提示词")
except Exception as e:
# 更详细的错误处理
pred = "Unknown"
print(f"⚠️ LLM调用异常: {e}")
t3 = time.time()
llm_ms = (t3 - t2) * 1000
latencies_llm.append(llm_ms)
exact = exact_match(pred, reference)
correct_flags.append(exact)
f1_val = f1_score(str(pred), str(reference))
b1_val = bleu1(str(pred), str(reference))
j_val = jaccard(str(pred), str(reference))
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
if verbose:
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {reference}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
samples.append({
"question": str(question),
"answer": str(reference),
"prediction": str(pred),
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": context_char_budget
},
"timing": {
"search_ms": search_ms,
"llm_ms": llm_ms
}
})
# 计算总体指标与聚合
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"context": {
"avg_tokens": ctx_avg_tokens,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": 0.0
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"group_id": group_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens,
"search_type": search_type,
"start_index": start_index,
"llm_id": SELECTED_LLM_ID,
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
},
"timestamp": datetime.now().isoformat(),
}
try:
await connector.close()
except Exception:
pass
return result
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size")
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID默认 group_memsci")
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型hybrid 等同于 embedding")
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl")
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径JSON")
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
args = parser.parse_args()
sample_size = 0 if args.all else args.sample_size
verbose_flag = False if args.quiet else args.verbose
result = asyncio.run(
run_memsciqa_test(
sample_size=sample_size,
group_id=args.group_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
data_path=args.data_path,
start_index=args.start_index,
verbose=verbose_flag,
)
)
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果保存
out_path = args.output
if not out_path:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, "results")
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
try:
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n💾 结果已保存: {out_path}")
except Exception as e:
print(f"⚠️ 结果保存失败: {e}")
if __name__ == "__main__":
main()

View File

@@ -1,150 +0,0 @@
import argparse
import asyncio
import json
import os
import sys
from typing import Any, Dict
# Add src directory to Python path for proper imports when running from evaluation directory
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
async def run(
dataset: str,
sample_size: int,
reset_group: bool,
group_id: str | None,
judge_model: str | None = None,
search_limit: int | None = None,
context_char_budget: int | None = None,
llm_temperature: float | None = None,
llm_max_tokens: int | None = None,
search_type: str | None = None,
start_index: int | None = None,
max_contexts_per_item: int | None = None,
) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
group_id = group_id or SELECTED_GROUP_ID
if reset_group:
connector = Neo4jConnector()
try:
await connector.delete_group(group_id)
finally:
await connector.close()
if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_locomo_eval(**kwargs)
if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
if start_index is not None:
kwargs["start_index"] = start_index
if max_contexts_per_item is not None:
kwargs["max_contexts_per_item"] = max_contexts_per_item
return await run_longmemeval_test(**kwargs)
raise ValueError(f"未知数据集: {dataset}")
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens不提供则使用各脚本默认")
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
# 仅透传到 longmemeval其他数据集忽略
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval起始样本索引不提供则用脚本默认")
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval每条样本摄入的上下文数量上限不提供则用脚本默认")
parser.add_argument("--output", type=str, default=None, help="可选将评估结果保存到指定文件路径JSON不提供时默认保存到 evaluation/<dataset>/results 目录")
args = parser.parse_args()
result = asyncio.run(run(
args.dataset,
args.sample_size,
args.reset_group,
args.group_id,
args.judge_model,
args.search_limit,
args.context_char_budget,
args.llm_temperature,
args.llm_max_tokens,
args.search_type,
args.start_index,
args.max_contexts_per_item,
))
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果输出逻辑保持不变
if args.output:
out_path = args.output
else:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
out_filename = f"{args.dataset}_{args.sample_size}.json"
out_path = os.path.join(dataset_results_dir, out_filename)
out_dir = os.path.dirname(out_path)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n结果已保存到: {out_path}")
if __name__ == "__main__":
main()

View File

@@ -187,11 +187,11 @@ class ChunkerClient:
async def generate_chunks(self, dialogue: DialogData):
"""
Generate chunks following 1 Message = 1 Chunk strategy.
Each message creates one chunk, directly inheriting role information.
If a message is too long, it will be split into multiple sub-chunks,
each maintaining the same speaker.
Raises:
ValueError: If dialogue has no messages or chunking fails
"""
@@ -201,9 +201,9 @@ class ChunkerClient:
f"Dialogue {dialogue.ref_id} has no messages. "
f"Cannot generate chunks from empty dialogue."
)
dialogue.chunks = []
# 按消息分块:每个消息创建一个或多个 chunk直接继承角色
for msg_idx, msg in enumerate(dialogue.context.msgs):
# Validate message has required attributes
@@ -212,13 +212,13 @@ class ChunkerClient:
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
f"missing 'role' or 'msg' attribute"
)
msg_content = msg.msg.strip()
# Skip empty messages
if not msg_content:
continue
# 如果消息太长,可以进一步分块
if len(msg_content) > self.chunk_size:
# 对单个消息的内容进行分块
@@ -228,14 +228,14 @@ class ChunkerClient:
raise ValueError(
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
)
for idx, sub_chunk in enumerate(sub_chunks):
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
sub_chunk_text = sub_chunk_text.strip()
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
continue
chunk = Chunk(
content=f"{msg.role}: {sub_chunk_text}",
speaker=msg.role, # 直接继承角色
@@ -260,7 +260,7 @@ class ChunkerClient:
},
)
dialogue.chunks.append(chunk)
# Validate we generated at least one chunk
if not dialogue.chunks:
raise ValueError(
@@ -268,7 +268,7 @@ class ChunkerClient:
f"All messages were either empty or too short. "
f"Messages count: {len(dialogue.context.msgs)}"
)
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:

View File

@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
"""Parameters for temporal search queries in the knowledge graph.
Attributes:
group_id: Group ID to filter search results (default: 'test')
end_user_id: Group ID to filter search results (default: 'test')
apply_id: Application ID to filter search results
user_id: User ID to filter search results
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
limit: Maximum number of results to return (default: 3)
"""
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
start_date: Optional[str] = Field(None, description="The start date for the search.")

View File

@@ -103,9 +103,7 @@ class Edge(BaseModel):
id: Unique identifier for the edge
source: ID of the source node
target: ID of the target node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective)
@@ -113,9 +111,7 @@ class Edge(BaseModel):
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.")
target: str = Field(..., description="The ID of the target node.")
group_id: str = Field(..., description="The group ID of the edge.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
@@ -185,18 +181,14 @@ class Node(BaseModel):
Attributes:
id: Unique identifier for the node
name: Name of the node
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective)
"""
id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.")
group_id: str = Field(..., description="The group ID of the node.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
end_user_id: str = Field(..., description="The end user ID of the node.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")

View File

@@ -55,7 +55,7 @@ class Statement(BaseModel):
Attributes:
id: Unique identifier for the statement
chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy
end_user_id: Optional group ID for multi-tenancy
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement
@@ -73,7 +73,7 @@ class Statement(BaseModel):
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
context: Full conversation context
dialog_embedding: Optional embedding vector for the entire dialog
ref_id: Reference ID linking to external dialog system
group_id: Group ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
end_user_id: End user ID for multi-tenancy
created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
group_id: str = Field(default=..., description="Group ID of dialogue data")
user_id: str = Field(..., description="USER ID of dialogue data")
apply_id: str = Field(..., description="APPLY ID of dialogue data")
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
return []
def assign_group_id_to_statements(self) -> None:
"""Assign this dialog's group_id to all statements in all chunks.
"""Assign this dialog's end_user_id to all statements in all chunks.
This method updates statements that don't have a group_id set.
This method updates statements that don't have a end_user_id set.
"""
for chunk in self.chunks:
for statement in chunk.statements:
if statement.group_id is None:
statement.group_id = self.group_id
if statement.end_user_id is None:
statement.end_user_id = self.end_user_id

View File

@@ -6,6 +6,7 @@ import os
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
@@ -396,13 +397,13 @@ def rerank_with_activation(
return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger.
Args:
query_text: The search query text
search_type: Type of search (keyword, embedding, hybrid)
group_id: Group identifier for filtering
end_user_id: Group identifier for filtering
limit: Maximum number of results
include: List of result types to include
log_file: Deprecated parameter, kept for backward compatibility
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
# Log using the standard logger
logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, "
f"group_id={group_id}, limit={limit}, include={include}"
f"end_user_id={end_user_id}, limit={limit}, include={include}"
)
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
async def run_hybrid_search(
query_text: str,
search_type: str,
group_id: str | None,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
@@ -715,7 +716,7 @@ async def run_hybrid_search(
}
# Log the search query
log_search_query(query_text, search_type, group_id, limit, include)
log_search_query(query_text, search_type, end_user_id, limit, include)
connector = Neo4jConnector()
results = {}
@@ -732,7 +733,7 @@ async def run_hybrid_search(
search_graph(
connector=connector,
q=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include
)
@@ -769,7 +770,7 @@ async def run_hybrid_search(
connector=connector,
embedder_client=embedder,
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
)
@@ -916,9 +917,7 @@ async def run_hybrid_search(
async def search_by_temporal(
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -929,7 +928,7 @@ async def search_by_temporal(
Temporal search across Statements.
- Matches statements created between start_date and end_date
- Optionally filters by group_id
- Optionally filters by end_user_id
- Returns up to 'limit' statements
"""
connector = Neo4jConnector()
@@ -939,9 +938,7 @@ async def search_by_temporal(
end_date = normalize_date_safe(end_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"end_user_id": end_user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
@@ -950,9 +947,7 @@ async def search_by_temporal(
})
statements = await search_graph_by_temporal(
connector=connector,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
end_user_id=params.end_user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
@@ -964,9 +959,7 @@ async def search_by_temporal(
async def search_by_keyword_temporal(
query_text: str,
group_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
invalid_date = normalize_date_safe(invalid_date)
params = TemporalSearchParams.model_validate({
"group_id": group_id,
"apply_id": apply_id,
"user_id": user_id,
"end_user_id": end_user_id,
"start_date": start_date,
"end_date": end_date,
"valid_date": valid_date,
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
statements = await search_graph_by_keyword_temporal(
connector=connector,
query_text=query_text,
group_id=params.group_id,
apply_id=params.apply_id,
user_id=params.user_id,
end_user_id=params.end_user_id,
start_date=params.start_date,
end_date=params.end_date,
valid_date=params.valid_date,
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id(
chunk_id: str,
group_id: Optional[str] = "test",
end_user_id: Optional[str] = "test",
limit: int = 1,
):
"""
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
chunks = await search_graph_by_chunk_id(
connector=connector,
chunk_id=chunk_id,
group_id=group_id,
end_user_id=end_user_id,
limit=limit
)
return {"chunks": chunks}

View File

@@ -555,8 +555,8 @@ class DataPreprocessor:
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
# 获取group_id如果不存在则生成默认值
group_id = item.get('group_id', f'group_default_{i}')
# 获取end_user_id如果不存在则生成默认值
end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -574,7 +574,7 @@ class DataPreprocessor:
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata
@@ -644,7 +644,7 @@ class DataPreprocessor:
context = ConversationContext(msgs=messages)
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
group_id = item.get('group_id', f'group_default_{i}')
end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -657,7 +657,7 @@ class DataPreprocessor:
dialog_data = DialogData(
context=context,
ref_id=dialog_id,
group_id=group_id,
end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
metadata=metadata

View File

@@ -199,7 +199,7 @@ def accurate_match(
entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
"""
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
返回: (deduped_entities, id_redirect, exact_merge_map)
"""
exact_merge_map: Dict[str, Dict] = {}
@@ -210,8 +210,8 @@ def accurate_match(
for ent in entity_nodes:
name_norm = (getattr(ent, "name", "") or "").strip()
type_norm = (getattr(ent, "entity_type", "") or "").strip()
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 group_id 为范围边界
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
if key not in canonical_map:
canonical_map[key] = ent
id_redirect[ent.id] = ent.id
@@ -223,11 +223,11 @@ def accurate_match(
id_redirect[ent.id] = canonical.id
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
try:
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
if k not in exact_merge_map:
exact_merge_map[k] = {
"canonical_id": canonical.id,
"group_id": canonical.group_id,
"end_user_id": canonical.end_user_id,
"name": canonical.name,
"entity_type": canonical.entity_type,
"merged_ids": set(),
@@ -596,7 +596,7 @@ def fuzzy_match(
b = deduped_entities[j]
# 跳过不同业务组的实体
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
j += 1
continue
@@ -671,7 +671,7 @@ def fuzzy_match(
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
fuzzy_merge_records.append(
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
)
except Exception:
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
# 记录 LLM 融合日志
try:
llm_records.append(
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
)
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
except Exception:
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
id_redirect[k] = a.id
try:
disamb_records.append(
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
)
except Exception:
pass

View File

@@ -174,7 +174,7 @@ async def _judge_pair(
pass
# 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
except Exception:
pass
ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim,
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
a = entity_nodes[i]
for j in range(i + 1, len(entity_nodes)):
b = entity_nodes[j]
# 规则1必须属于同一组group_id相同不同组的实体不重复
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
# 规则1必须属于同一组end_user_id相同不同组的实体不重复
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue
# 规则2类型必须兼容调用_simple_type_ok判断
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
- max_rounds: upper bound for iterative passes (default 3)
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
Returns:
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
"""
group_id 分块,避免跨组实体在同一块,减少无效候选对
end_user_id 分块,避免跨组实体在同一块,减少无效候选对
Args:
nodes: 实体节点列表
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
"""
groups: Dict[str, List[ExtractedEntityNode]] = {}
for e in nodes:
gid = getattr(e, "group_id", None)
gid = getattr(e, "end_user_id", None)
groups.setdefault(str(gid), []).append(e)
blocks: List[List[ExtractedEntityNode]] = []
for gid, arr in groups.items():
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
# 步骤1折叠实体合并已确定的重复实体减少后续计算量
current_nodes = _collapse_nodes(current_nodes)
# 步骤2分块group_id分块避免跨组处理
# 步骤2分块end_user_id分块避免跨组处理
blocks = _partition_blocks(current_nodes)
if not blocks: # 无块可处理(实体已全部折叠),退出循环
break
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
a = entity_nodes[i]
b = entity_nodes[j]
# 必须同组
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue
ta = getattr(a, "entity_type", None)
tb = getattr(b, "entity_type", None)

View File

@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
return ExtractedEntityNode(
id=row.get("id"),
name=row.get("name") or "",
group_id=row.get("group_id") or "",
end_user_id=row.get("end_user_id") or "",
user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")),
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
connector: Neo4jConnector,
group_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
end_user_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
"""
第二层去重消歧:
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
- 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID
"""
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
]
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成
connector=connector, group_id=group_id,
connector=connector, end_user_id=end_user_id,
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动
)

View File

@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
if pipeline_config is None:
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略
group_id: Optional[str] = None
# 先探测 end_user_id决定报告写入策略
end_user_id: Optional[str] = None
for dd in dialog_data_list:
group_id = getattr(dd, "group_id", None)
if group_id:
end_user_id = getattr(dd, "end_user_id", None)
if end_user_id:
break
# 第一层去重消歧
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
try:
if group_id:
if end_user_id:
if connector:
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
connector=connector,
group_id=group_id,
end_user_id=end_user_id,
entity_nodes=dedup_entity_nodes,
statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges,
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
else:
print("Skip second-layer dedup: missing connector")
else:
print("Skip second-layer dedup: missing group_id")
print("Skip second-layer dedup: missing end_user_id")
except Exception as e:
print(f"Second-layer dedup failed: {e}")

View File

@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
for d_idx, dialog in enumerate(dialog_data_list):
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
for c_idx, chunk in enumerate(dialog.chunks):
all_chunks.append((chunk, dialog.group_id, dialogue_content))
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
# 全局并行处理所有分块
async def extract_for_chunk(chunk_data, chunk_index):
nonlocal completed_chunks
chunk, group_id, dialogue_content = chunk_data
chunk, end_user_id, dialogue_content = chunk_data
try:
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
# 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
config_id = dialog_data_list[0].config_id
# 加载DataConfig
data_config = None
# 加载MemoryConfig
memory_config = None
if config_id:
try:
from app.db import SessionLocal
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
db = SessionLocal()
try:
data_config = DataConfigRepository.get_by_id(db, config_id)
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
finally:
db.close()
if data_config and not data_config.emotion_enabled:
if memory_config and not memory_config.emotion_enabled:
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
return [{} for _ in dialog_data_list]
except Exception as e:
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
return [{} for _ in dialog_data_list]
else:
logger.info("未找到config_id跳过情绪提取")
return [{} for _ in dialog_data_list]
# 如果配置未启用情绪提取,直接返回空映射
if not data_config or not data_config.emotion_enabled:
if not memory_config or not memory_config.emotion_enabled:
logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list]
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
total_statements += 1
# 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config))
all_statements.append((statement, memory_config))
statement_metadata.append((d_idx, statement.id))
filtered_statements += 1
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
# 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService
emotion_service = EmotionExtractionService(
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None
llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
)
# 全局并行处理所有陈述句
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
id=dialog_data.id,
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
ref_id=dialog_data.ref_id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
id=chunk.id,
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
dialog_id=dialog_data.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=chunk.content,
chunk_embedding=chunk.chunk_embedding,
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
statement_chunk_edge = StatementChunkEdge(
source=statement.id,
target=chunk.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
)
@@ -1072,13 +1064,16 @@ class ExtractionOrchestrator:
if statement.triplet_extraction_info:
triplet_info = statement.triplet_extraction_info
# 创建实体索引到ID的映射
# 创建实体索引到ID的映射(支持多种索引方式)
entity_idx_to_id = {}
# 创建实体节点
for entity_idx, entity in enumerate(triplet_info.entities):
# 映射实体索引到实体ID
# 映射实体索引到实体ID(使用多个键以提高容错性)
# 1. 使用实体自己的 entity_idx
entity_idx_to_id[entity.entity_idx] = entity.id
# 2. 使用枚举索引从0开始
entity_idx_to_id[entity_idx] = entity.id
if entity.id not in entity_id_set:
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
@@ -1095,9 +1090,7 @@ class ExtractionOrchestrator:
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
@@ -1112,9 +1105,7 @@ class ExtractionOrchestrator:
source=statement.id,
target=entity.id,
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
)
@@ -1134,9 +1125,7 @@ class ExtractionOrchestrator:
relation_type=triplet.predicate,
statement=statement.statement,
source_statement_id=statement.id,
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
@@ -1163,9 +1152,18 @@ class ExtractionOrchestrator:
relationship_result
)
else:
logger.warning(
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
f"object_id={triplet.object_id}, statement_id={statement.id}"
# 改进的警告信息,包含更多调试信息
missing_subject = "subject" if not subject_entity_id else ""
missing_object = "object" if not object_entity_id else ""
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
logger.debug(
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
f"object_id={triplet.object_id} ({triplet.object_name}), "
f"predicate={triplet.predicate}, "
f"statement_id={statement.id}, "
f"available_indices={sorted(entity_idx_to_id.keys())}"
)
logger.info(
@@ -1763,14 +1761,14 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
end_user_id: str = "group_1",
indices: Optional[List[int]] = None,
) -> List[DialogData]:
"""从测试数据生成分块对话
Args:
chunker_strategy: 分块策略(默认: RecursiveChunker
group_id: 组ID
end_user_id: 组ID
indices: 要处理的数据索引列表(可选)
Returns:
@@ -1834,7 +1832,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData(
context=conversation_context,
ref_id=data['id'],
group_id=group_id,
end_user_id=end_user_id,
metadata=dialog_metadata,
)
@@ -1936,7 +1934,7 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "default",
end_user_id: str = "default",
user_id: str = "default",
apply_id: str = "default",
indices: Optional[List[int]] = None,
@@ -1948,7 +1946,7 @@ async def get_chunked_dialogs_with_preprocessing(
Args:
chunker_strategy: 分块策略
group_id: 组ID
end_user_id: 组ID
user_id: 用户ID
apply_id: 应用ID
indices: 要处理的数据索引列表
@@ -1976,11 +1974,9 @@ async def get_chunked_dialogs_with_preprocessing(
indices=indices,
)
# 设置 group_id, user_id, apply_id
# 设置 end_user_id
for dd in preprocessed_data:
dd.group_id = group_id
dd.user_id = user_id
dd.apply_id = apply_id
dd.end_user_id = end_user_id
# 步骤2: 语义剪枝
try:

View File

@@ -193,9 +193,9 @@ async def _process_chunk_summary(
node = MemorySummaryNode(
id=uuid4().hex,
name=title if title else f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id,
user_id=dialog.user_id,
apply_id=dialog.apply_id,
end_user_id=dialog.end_user_id,
user_id=dialog.end_user_id,
apply_id=dialog.end_user_id,
run_id=dialog.run_id, # 使用 dialog 的 run_id
created_at=datetime.now(),
expired_at=datetime(9999, 12, 31),

View File

@@ -82,12 +82,12 @@ class StatementExtractor:
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements
Args:
chunk: Chunk object to process
group_id: Group ID to assign to all statements in this chunk
end_user_id: Group ID to assign to all statements in this chunk
dialogue_content: Full dialogue content to provide as context
Returns:
@@ -158,7 +158,7 @@ class StatementExtractor:
temporal_info=temporal_type,
relevence_info=relevence_info,
chunk_id=chunk.id,
group_id=group_id,
end_user_id=end_user_id,
speaker=chunk_speaker,
)
@@ -184,10 +184,10 @@ class StatementExtractor:
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
results = await asyncio.gather(
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
return_exceptions=True
)
@@ -225,7 +225,7 @@ class StatementExtractor:
for i, statement in enumerate(statements, 1):
f.write(f"Statement {i}:\n")
f.write(f"Id: {statement.id}\n")
f.write(f"Group Id: {statement.group_id}\n")
f.write(f"Group Id: {statement.end_user_id}\n")
f.write(f"Content: {statement.statement}\n")
f.write(f"Type: {statement.stmt_type.value}\n")
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
@@ -298,7 +298,7 @@ class StatementExtractor:
dialog_sections.append({
"dialog_id": dialog.ref_id,
"group_id": dialog.group_id,
"end_user_id": dialog.end_user_id,
"content": dialog.content if getattr(dialog, "content", None) else "",
"strong": strong_relations,
"weak": weak_relations,
@@ -312,7 +312,7 @@ class StatementExtractor:
for idx, section in enumerate(dialog_sections, 1):
f.write(f"Dialog {idx}:\n")
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
f.write(f"Group ID: {section.get('group_id', '')}\n")
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
f.write("Content:\n")
f.write(f"{section.get('content', '')}\n")
f.write("-" * 40 + "\n\n")

View File

@@ -132,7 +132,7 @@ class TemporalExtractor:
prompt_logger.info("")
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
prompt_logger.info(
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
)
except Exception:
pass

View File

@@ -116,7 +116,7 @@ class TripletExtractor:
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
try:
prompt_logger.info(
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
)
except Exception:
pass

View File

@@ -75,7 +75,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> Dict[str, Any]:
"""
@@ -91,7 +91,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选用于过滤
end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间)
Returns:
@@ -123,7 +123,7 @@ class AccessHistoryManager:
for attempt in range(self.max_retries):
try:
# 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
raise ValueError(
@@ -142,7 +142,7 @@ class AccessHistoryManager:
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
end_user_id=end_user_id
)
logger.info(
@@ -172,7 +172,7 @@ class AccessHistoryManager:
self,
node_ids: List[str],
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""
@@ -184,7 +184,7 @@ class AccessHistoryManager:
Args:
node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型)
group_id: 组ID可选
end_user_id: 组ID可选
current_time: 当前时间(可选)
Returns:
@@ -202,7 +202,7 @@ class AccessHistoryManager:
task = self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
end_user_id=end_user_id,
current_time=current_time
)
tasks.append(task)
@@ -235,7 +235,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
"""
检查节点数据的一致性
@@ -249,14 +249,14 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举
- 错误描述(如果不一致)
"""
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
return ConsistencyCheckResult.CONSISTENT, None
@@ -305,7 +305,7 @@ class AccessHistoryManager:
async def check_batch_consistency(
self,
node_label: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 1000
) -> Dict[str, Any]:
"""
@@ -313,7 +313,7 @@ class AccessHistoryManager:
Args:
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
limit: 检查的最大节点数
Returns:
@@ -329,16 +329,16 @@ class AccessHistoryManager:
MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL
"""
if group_id:
query += " AND n.group_id = $group_id"
if end_user_id:
query += " AND n.end_user_id = $end_user_id"
query += """
RETURN n.id as id
LIMIT $limit
"""
params = {"limit": limit}
if group_id:
params["group_id"] = group_id
if end_user_id:
params["end_user_id"] = end_user_id
results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results]
@@ -351,7 +351,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
if result == ConsistencyCheckResult.CONSISTENT:
@@ -387,7 +387,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> bool:
"""
自动修复节点的数据不一致问题
@@ -401,7 +401,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
bool: 修复成功返回True否则返回False
@@ -411,7 +411,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
end_user_id=end_user_id
)
if result == ConsistencyCheckResult.CONSISTENT:
@@ -419,7 +419,7 @@ class AccessHistoryManager:
return True
# 获取节点数据
node_data = await self._fetch_node(node_id, node_label, group_id)
node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False
@@ -457,8 +457,8 @@ class AccessHistoryManager:
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
if end_user_id:
query += " WHERE n.end_user_id = $end_user_id"
query += """
SET n += $repair_data
RETURN n
@@ -468,8 +468,8 @@ class AccessHistoryManager:
'node_id': node_id,
'repair_data': repair_data
}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
await self.connector.execute_query(query, **params)
@@ -491,7 +491,7 @@ class AccessHistoryManager:
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
获取节点数据
@@ -499,7 +499,7 @@ class AccessHistoryManager:
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None
@@ -507,8 +507,8 @@ class AccessHistoryManager:
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
if end_user_id:
query += " WHERE n.end_user_id = $end_user_id"
query += """
RETURN n.id as id,
n.importance_score as importance_score,
@@ -519,8 +519,8 @@ class AccessHistoryManager:
"""
params = {'node_id': node_id}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)
@@ -585,7 +585,7 @@ class AccessHistoryManager:
node_id: str,
node_label: str,
update_data: Dict[str, Any],
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> Dict[str, Any]:
"""
原子性更新节点(使用乐观锁)
@@ -597,7 +597,7 @@ class AccessHistoryManager:
node_id: 节点ID
node_label: 节点标签
update_data: 更新数据
group_id: 组ID可选
end_user_id: 组ID可选
Returns:
Dict[str, Any]: 更新后的节点数据
@@ -606,13 +606,13 @@ class AccessHistoryManager:
RuntimeError: 如果更新失败或发生版本冲突
"""
# 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, group_id):
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
# 步骤1读取当前节点并获取版本号
read_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
read_query += " WHERE n.group_id = $group_id"
if end_user_id:
read_query += " WHERE n.end_user_id = $end_user_id"
read_query += """
RETURN n.id as id,
n.version as version,
@@ -624,8 +624,8 @@ class AccessHistoryManager:
"""
read_params = {'node_id': node_id}
if group_id:
read_params['group_id'] = group_id
if end_user_id:
read_params['end_user_id'] = end_user_id
read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single()
@@ -656,8 +656,8 @@ class AccessHistoryManager:
# 构建 WHERE 子句
where_conditions = []
if group_id:
where_conditions.append("n.group_id = $group_id")
if end_user_id:
where_conditions.append("n.end_user_id = $end_user_id")
# 添加版本检查
if current_version > 0:
@@ -695,8 +695,8 @@ class AccessHistoryManager:
'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count']
}
if group_id:
update_params['group_id'] = group_id
if end_user_id:
update_params['end_user_id'] = end_user_id
update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single()
@@ -720,7 +720,7 @@ class AccessHistoryManager:
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
end_user_id=end_user_id
)
return result
except Exception as e:

View File

@@ -11,9 +11,10 @@ Functions:
import logging
from typing import Optional, Dict, Any
from uuid import UUID
from sqlalchemy.orm import Session
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
def load_actr_config_from_db(
db: Session,
config_id: Optional[int] = None
config_id: Optional[UUID] = None
) -> Dict[str, Any]:
"""
从数据库加载 ACT-R 配置参数
从 PostgreSQL 的 data_config 表读取配置参数,
从 PostgreSQL 的 memory_config 表读取配置参数,
并计算派生参数(如 forgetting_rate
Args:
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
# 从数据库加载配置
try:
repository = DataConfigRepository()
repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None:
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
def create_actr_calculator_from_config(
db: Session,
config_id: Optional[int] = None
config_id: Optional[UUID] = None
) -> ACTRCalculator:
"""
从数据库配置创建 ACTRCalculator 实例
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
ValueError: 如果指定的 config_id 不存在
Examples:
>>> from sqlalchemy.orm import Session
>>> db = Session()
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
>>> # 使用计算器
>>> activation = calculator.calculate_memory_activation(...)
"""
# 加载配置
config = load_actr_config_from_db(db, config_id)

View File

@@ -16,6 +16,7 @@ Classes:
import logging
from typing import Dict, Any, Optional
from uuid import UUID
from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
@@ -66,10 +67,10 @@ class ForgettingScheduler:
async def run_forgetting_cycle(
self,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
max_merge_batch_size: int = 100,
min_days_since_access: int = 30,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
db = None
) -> Dict[str, Any]:
"""
@@ -77,7 +78,7 @@ class ForgettingScheduler:
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
max_merge_batch_size: 单次最大融合节点对数(默认 100
min_days_since_access: 最小未访问天数(默认 30 天)
config_id: 配置ID可选用于获取 llm_id
@@ -107,19 +108,19 @@ class ForgettingScheduler:
start_time_iso = start_time.isoformat()
logger.info(
f"开始遗忘周期: group_id={group_id}, "
f"开始遗忘周期: end_user_id={end_user_id}, "
f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}"
)
try:
# 步骤1统计遗忘前的节点数量
nodes_before = await self._count_knowledge_nodes(group_id)
nodes_before = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘前节点总数: {nodes_before}")
# 步骤2识别可遗忘的节点对
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id,
end_user_id=end_user_id,
min_days_since_access=min_days_since_access
)
@@ -213,7 +214,7 @@ class ForgettingScheduler:
'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'],
'group_id': group_id
'end_user_id': end_user_id
}
entity_node = {
@@ -222,7 +223,7 @@ class ForgettingScheduler:
'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'],
'group_id': group_id
'end_user_id': end_user_id
}
# 融合节点
@@ -262,7 +263,7 @@ class ForgettingScheduler:
continue
# 步骤6统计遗忘后的节点数量
nodes_after = await self._count_knowledge_nodes(group_id)
nodes_after = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘后节点总数: {nodes_after}")
# 步骤7生成遗忘报告
@@ -315,7 +316,7 @@ class ForgettingScheduler:
async def _count_knowledge_nodes(
self,
group_id: Optional[str] = None
end_user_id: Optional[str] = None
) -> int:
"""
统计知识层节点总数
@@ -323,7 +324,7 @@ class ForgettingScheduler:
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
Returns:
int: 知识层节点总数
@@ -333,16 +334,16 @@ class ForgettingScheduler:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
"""
if group_id:
query += " AND n.group_id = $group_id"
if end_user_id:
query += " AND n.end_user_id = $end_user_id"
query += """
RETURN count(n) as total
"""
params = {}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)

View File

@@ -13,6 +13,7 @@ Classes:
import logging
from typing import List, Dict, Any, Optional
from uuid import UUID
from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -90,7 +91,7 @@ class ForgettingStrategy:
async def find_forgettable_nodes(
self,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
min_days_since_access: int = 30
) -> List[Dict[str, Any]]:
"""
@@ -102,7 +103,7 @@ class ForgettingStrategy:
3. Statement 和 Entity 之间存在关系边
Args:
group_id: 组 ID可选用于过滤特定组的节点
end_user_id: 组 ID可选用于过滤特定组的节点
min_days_since_access: 最小未访问天数(默认 30 天)
Returns:
@@ -136,8 +137,8 @@ class ForgettingStrategy:
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
"""
if group_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
if end_user_id:
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
query += """
RETURN s.id as statement_id,
@@ -159,8 +160,8 @@ class ForgettingStrategy:
'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso
}
if group_id:
params['group_id'] = group_id
if end_user_id:
params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params)
@@ -176,7 +177,7 @@ class ForgettingStrategy:
self,
statement_node: Dict[str, Any],
entity_node: Dict[str, Any],
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
db = None
) -> str:
"""
@@ -247,8 +248,8 @@ class ForgettingStrategy:
entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance']
# 获取 group_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# 获取 end_user_id从 statement 或 entity 节点)
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
# 生成摘要内容
summary_text = await self._generate_summary(
@@ -325,7 +326,7 @@ class ForgettingStrategy:
last_access_time: $current_time,
access_count: 1,
version: 1,
group_id: $group_id,
end_user_id: $end_user_id,
created_at: datetime($current_time),
merged_at: datetime($current_time)
})
@@ -423,7 +424,7 @@ class ForgettingStrategy:
'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance,
'current_time': current_time_iso,
'group_id': group_id
'end_user_id': end_user_id
}
try:
@@ -462,7 +463,7 @@ class ForgettingStrategy:
statement_text: str,
entity_name: str,
entity_type: str,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
db = None
) -> str:
"""
@@ -527,7 +528,7 @@ class ForgettingStrategy:
statement_text, entity_name, entity_type
)
async def _get_llm_client(self, db, config_id: int):
async def _get_llm_client(self, db, config_id: UUID):
"""
从数据库获取 LLM 客户端
@@ -539,11 +540,11 @@ class ForgettingStrategy:
LLM 客户端实例,如果无法获取则返回 None
"""
try:
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 从数据库读取配置
repository = DataConfigRepository()
repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None or db_config.llm_id is None:

View File

@@ -37,7 +37,7 @@ __all__ = [
async def run_hybrid_search(
query_text: str,
search_type: str = "hybrid",
group_id: str | None = None,
end_user_id: str | None = None,
apply_id: str | None = None,
user_id: str | None = None,
limit: int = 50,
@@ -54,7 +54,7 @@ async def run_hybrid_search(
Args:
query_text: 查询文本
search_type: 搜索类型("hybrid", "keyword", "semantic"
group_id: 组ID过滤
end_user_id: 组ID过滤
apply_id: 应用ID过滤
user_id: 用户ID过滤
limit: 每个类别的最大结果数
@@ -104,7 +104,7 @@ async def run_hybrid_search(
# 执行搜索
result = await strategy.search(
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include,
alpha=alpha,

View File

@@ -77,7 +77,7 @@
# async def search(
# self,
# query_text: str,
# group_id: Optional[str] = None,
# end_user_id: Optional[str] = None,
# limit: int = 50,
# include: Optional[List[str]] = None,
# **kwargs
@@ -86,7 +86,7 @@
# Args:
# query_text: 查询文本
# group_id: 可选的组ID过滤
# end_user_id: 可选的组ID过滤
# limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve
@@ -94,7 +94,7 @@
# Returns:
# SearchResult: 搜索结果对象
# """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha)
@@ -107,14 +107,14 @@
# # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search(
# query_text=query_text,
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list
# )
# semantic_result = await self.semantic_strategy.search(
# query_text=query_text,
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list
# )
@@ -139,7 +139,7 @@
# metadata = self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# include=include_list,
# alpha=alpha,
@@ -165,7 +165,7 @@
# metadata=self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# end_user_id=end_user_id,
# limit=limit,
# error=str(e)
# )

View File

@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别
include_list = self._get_include_list(include)
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
results_dict = await search_graph(
connector=self.connector,
q=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata = self._create_metadata(
query_text=query_text,
search_type="keyword",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata=self._create_metadata(
query_text=query_text,
search_type="keyword",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
error=str(e)
)

View File

@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表statements, chunks, entities, summaries
**kwargs: 其他搜索参数
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
self,
query_text: str,
search_type: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
**kwargs
) -> Dict[str, Any]:
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
Args:
query_text: 查询文本
search_type: 搜索类型
group_id: 组ID
end_user_id: 组ID
limit: 结果限制
**kwargs: 其他元数据
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
metadata = {
"query": query_text,
"search_type": search_type,
"group_id": group_id,
"end_user_id": end_user_id,
"limit": limit,
"timestamp": datetime.now().isoformat()
}

View File

@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别
include_list = self._get_include_list(include)
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
connector=self.connector,
embedder_client=self.embedder_client,
query_text=query_text,
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata = self._create_metadata(
query_text=query_text,
search_type="semantic",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
include=include_list
)
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata=self._create_metadata(
query_text=query_text,
search_type="semantic",
group_id=group_id,
end_user_id=end_user_id,
limit=limit,
error=str(e)
)

View File

@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
target_keys = [
"id",
"statement",
"group_id",
"end_user_id",
"chunk_id",
"created_at",
"expired_at",
@@ -75,7 +75,7 @@ async def get_data(result):
"""
EXCLUDE_FIELDS = {
"user_id",
"group_id",
"end_user_id",
"entity_type",
"connect_strength",
"relationship_type",

View File

@@ -62,7 +62,7 @@ class ConfigAuditLogger:
self,
config_id: str,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
end_user_id: Optional[str] = None,
success: bool = True,
details: Optional[Dict[str, Any]] = None
):
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
Args:
config_id: 配置 ID
user_id: 用户 ID可选
group_id: 组 ID可选
end_user_id: 组 ID可选
success: 是否成功
details: 详细信息(可选)
"""
result = "SUCCESS" if success else "FAILED"
msg = (
f"CONFIG_LOAD config_id={config_id} "
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
f"result={result}"
)
if details:
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
self,
operation: str,
config_id: str,
group_id: str,
end_user_id: str,
success: bool = True,
duration: Optional[float] = None,
error: Optional[str] = None,
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
Args:
operation: 操作类型WRITE, READ 等)
config_id: 配置 ID
group_id: 组 ID
end_user_id: 组 ID
success: 是否成功
duration: 操作耗时(秒)
error: 错误信息(可选)
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
result = "SUCCESS" if success else "FAILED"
msg = (
f"{operation.upper()} config_id={config_id} "
f"group={group_id} result={result}"
f"group={end_user_id} result={result}"
)
if duration is not None:
msg += f" duration={duration:.2f}s"

View File

@@ -1,165 +0,0 @@
import copy
import re
from io import BytesIO
from PIL import Image
from app.core.rag.nlp import tokenize, is_english
from app.core.rag.nlp import rag_tokenizer
from app.core.rag.deepdoc.parser import PdfParser, PlainParser
from app.core.rag.deepdoc.parser.ppt_parser import RAGPptParser as PptParser
from PyPDF2 import PdfReader as pdf2_read
from app.core.rag.app.naive import by_plaintext, PARSERS
class Ppt(PptParser):
def __call__(self, fnm, from_page, to_page, callback=None):
txts = super().__call__(fnm, from_page, to_page)
callback(0.5, "Text extraction finished.")
import aspose.slides as slides
import aspose.pydrawing as drawing
imgs = []
with slides.Presentation(BytesIO(fnm)) as presentation:
for i, slide in enumerate(presentation.slides[from_page: to_page]):
try:
with BytesIO() as buffered:
slide.get_thumbnail(
0.1, 0.1).save(
buffered, drawing.imaging.ImageFormat.jpeg)
buffered.seek(0)
imgs.append(Image.open(buffered).copy())
except RuntimeError as e:
raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e
assert len(imgs) == len(
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
callback(0.9, "Image extraction finished")
self.is_english = is_english(txts)
return [(txts[i], imgs[i]) for i in range(len(txts))]
class Pdf(PdfParser):
def __init__(self):
super().__init__()
def __garbage(self, txt):
txt = txt.lower().strip()
if re.match(r"[0-9\.,%/-]+$", txt):
return True
if len(txt) < 3:
return True
return False
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(filename if not binary else binary,
zoomin, from_page, to_page, callback)
callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start))
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
len(self.boxes), len(self.page_images))
res = []
for i in range(len(self.boxes)):
lines = "\n".join([b["text"] for b in self.boxes[i]
if not self.__garbage(b["text"])])
res.append((lines, self.page_images[i]))
callback(0.9, "Page {}~{}: Parsing finished".format(
from_page, min(to_page, self.total_page)))
return res, []
class PlainPdf(PlainParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, callback=None, **kwargs):
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
page_txt = []
for page in self.pdf.pages[from_page: to_page]:
page_txt.append(page.extract_text())
callback(0.9, "Parsing finished")
return [(txt, None) for txt in page_txt], []
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs):
"""
The supported file formats are pdf, pptx.
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
"""
if parser_config is None:
parser_config = {}
eng = lang.lower() == "english"
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
res = []
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
if not binary:
with open(filename, "rb") as f:
binary = f.read()
ppt_parser = Ppt()
for pn, (txt, img) in enumerate(ppt_parser(
filename if not binary else binary, from_page, 1000000, callback)):
d = copy.deepcopy(doc)
pn += from_page
d["image"] = img
d["doc_type_kwd"] = "image"
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
tokenize(d, txt, eng)
res.append(d)
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
sections, _, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
vision_model=vision_model,
pdf_cls=Pdf,
**kwargs
)
if not sections:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
for pn, (txt, img) in enumerate(sections):
d = copy.deepcopy(doc)
pn += from_page
if img:
d["image"] = img
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
tokenize(d, txt, eng)
res.append(d)
return res
raise NotImplementedError(
"file type not supported yet(pptx, pdf supported)")
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)

View File

@@ -4,7 +4,7 @@ from enum import StrEnum, auto
class Field(StrEnum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
GROUP_KEY = "end_user_id"
VECTOR = auto()
# Sparse Vector aims to support full text search
SPARSE_VECTOR = auto()

View File

@@ -36,7 +36,7 @@ def generate_signed_url(
"""
if base_url is None:
# Use SERVER_IP or default to localhost
server_url = f"http://{settings.SERVER_IP}:8000/api"
server_url = settings.FILE_LOCAL_SERVER_URL
base_url = server_url
# Calculate expiration timestamp

View File

@@ -16,7 +16,7 @@ class BaiduSearchTool(BuiltinTool):
@property
def description(self) -> str:
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、视频搜索"
def get_required_config_parameters(self) -> List[str]:
return ["api_key"]
@@ -33,7 +33,7 @@ class BaiduSearchTool(BuiltinTool):
ToolParameter(
name="search_type",
type=ParameterType.STRING,
description="搜索类型",
description="搜索类型, web: 网页搜索news新闻搜索image图片搜索video视频搜索",
required=False,
default="web",
enum=["web", "news", "image", "video"]

View File

@@ -26,7 +26,7 @@ logger = get_config_logger()
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
"""Parse model ID from string or UUID."""
if model_id is None:
return None
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
model_type: str,
db: Session,
tenant_id: Optional[UUID] = None,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None
) -> tuple[str, bool]:
"""Validate that a model exists and is active.
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
db: Session,
tenant_id: Optional[UUID] = None,
required: bool = False,
config_id: Optional[int] = None,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]:
"""Validate and resolve a model ID, checking existence and active status.
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
def validate_embedding_model(
config_id: int,
config_id: UUID,
embedding_id: Union[str, UUID, None],
db: Session,
tenant_id: Optional[UUID] = None,
@@ -256,7 +256,7 @@ def validate_embedding_model(
def validate_llm_model(
config_id: int,
config_id: UUID,
llm_id: Union[str, UUID, None],
db: Session,
tenant_id: Optional[UUID] = None,

View File

@@ -11,16 +11,12 @@ from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.graph_builder import GraphBuilder
from app.core.workflow.expression_evaluator import evaluate_expression
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry
# from app.core.tools.executor import ToolExecutor
# from app.core.tools.langchain_adapter import LangchainAdapter
# TOOL_MANAGEMENT_AVAILABLE = True
# from app.db import get_db
from app.core.workflow.template_renderer import render_template
logger = logging.getLogger(__name__)
@@ -55,6 +51,8 @@ class WorkflowExecutor:
self.execution_config = workflow_config.get("execution_config", {})
self.start_node_id = None
self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None
self.checkpoint_config = RunnableConfig(
configurable={
@@ -127,7 +125,6 @@ class WorkflowExecutor:
"user_id": self.user_id,
"error": None,
"error_node": None,
"streaming_buffer": {}, # 流式缓冲区
"cycle_nodes": [
node.get("id")
for node in self.workflow_config.get("nodes")
@@ -139,9 +136,8 @@ class WorkflowExecutor:
}
}
def _build_final_output(self, result, elapsed_time):
def _build_final_output(self, result, elapsed_time, final_output):
node_outputs = result.get("node_outputs", {})
final_output = self._extract_final_output(node_outputs)
token_usage = self._aggregate_token_usage(node_outputs)
conversation_id = None
for node_id, node_output in node_outputs.items():
@@ -161,6 +157,21 @@ class WorkflowExecutor:
"error": result.get("error"),
}
def _update_end_activate(self, node_id):
for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(node_id)
if self.end_outputs[node].activate and self.activate_end is None:
self.activate_end = node
@staticmethod
def _trans_output_string(content):
if isinstance(content, str):
return content
elif isinstance(content, list):
return "\n".join(content)
else:
return str(content)
def build_graph(self, stream=False) -> CompiledStateGraph:
"""构建 LangGraph
@@ -173,6 +184,7 @@ class WorkflowExecutor:
stream=stream,
)
self.start_node_id = builder.start_node_id
self.end_outputs = builder.end_node_map
graph = builder.build()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
@@ -205,14 +217,34 @@ class WorkflowExecutor:
try:
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
full_content = ''
for end_info in self.end_outputs.values():
output_template = "".join([output.literal for output in end_info.outputs])
full_content += render_template(
output_template,
result.get("variables", {}),
result.get("runtime_vars", {}),
strict=False
)
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return self._build_final_output(result, elapsed_time)
return self._build_final_output(result, elapsed_time, full_content)
except Exception as e:
# 计算耗时(即使失败也记录)
@@ -261,7 +293,7 @@ class WorkflowExecutor:
"data": {
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"timestamp": start_time.isoformat()
"timestamp": int(start_time.timestamp() * 1000)
}
}
@@ -273,6 +305,7 @@ class WorkflowExecutor:
# 3. Execute workflow
try:
chunk_count = 0
full_content = ''
async for event in graph.astream(
initial_state,
@@ -293,20 +326,39 @@ class WorkflowExecutor:
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
f"- execution_id: {self.execution_id}")
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index"),
"is_prefix": data.get("is_prefix"),
"is_suffix": data.get("is_suffix"),
"conversation_id": input_data.get("conversation_id"),
if event_type == "node_chunk":
node_id = data.get("node_id")
if self.activate_end:
end_info = self.end_outputs.get(self.activate_end)
if not end_info or end_info.cursor >= len(end_info.outputs):
continue
current_output = end_info.outputs[end_info.cursor]
if current_output.is_variable and current_output.depends_on_node(node_id):
if data.get("done"):
end_info.cursor += 1
else:
full_content += data.get("chunk")
yield {
"event": "message",
"data": {
"chunk": data.get("chunk")
}
}
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
f"- execution_id: {self.execution_id}")
elif event_type == "node_error":
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
"node_id": data.get("node_id"),
"status": "failed",
"input": data.get("input_data"),
"elapsed_time": data.get("elapsed_time"),
"output": None,
"error": data.get("error")
}
}
}
elif mode == "debug":
# Handle debug information (node execution status)
@@ -325,14 +377,15 @@ class WorkflowExecutor:
conversation_id = input_data.get("conversation_id")
logger.info(f"[NODE-START] Node starts execution: {node_name} "
f"- execution_id: {self.execution_id}")
yield {
"event": "node_start",
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": self.execution_id,
"timestamp": data.get("timestamp"),
"timestamp": int(datetime.datetime.fromisoformat(
data.get("timestamp")
).timestamp() * 1000),
}
}
elif event_type == "task_result":
@@ -351,21 +404,120 @@ class WorkflowExecutor:
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": self.execution_id,
"timestamp": data.get("timestamp"),
"state": result.get("node_outputs", {}).get(node_name),
"timestamp": int(datetime.datetime.fromisoformat(
data.get("timestamp")
).timestamp() * 1000),
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
}
}
elif mode == "updates":
# Handle state updates - store final state
# TODO:流式输出点
for node_id in data.keys():
self._update_end_activate(node_id)
wait = False
state = graph.get_state(config=self.checkpoint_config)
node_outputs = state.values.get("runtime_vars", {})
for _ in data.keys():
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
while self.activate_end and not wait:
message = ''
logger.info(self.activate_end)
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
while content.activate:
if not content.is_variable:
full_content += content.literal
message += content.literal
else:
try:
chunk = evaluate_expression(
content.literal,
variables={},
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
break
content = end_info.outputs[end_info.cursor]
if end_info.cursor != len(end_info.outputs):
wait = True
else:
self.end_outputs.pop(self.activate_end)
self.activate_end = None
for node_id in data.keys():
self._update_end_activate(node_id)
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
f"- execution_id: {self.execution_id}")
result = graph.get_state(self.checkpoint_config).values
while self.activate_end:
message = ''
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
if not content.is_variable:
message += content.literal
else:
node_outputs = result.get("runtime_vars", {})
variables = result.get("variables", {})
try:
chunk = evaluate_expression(
content.literal,
variables=variables,
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
if self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0]
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
result = graph.get_state(self.checkpoint_config).values
logger.info(result)
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
logger.info(
f"Workflow execution completed (streaming), "
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
@@ -374,7 +526,7 @@ class WorkflowExecutor:
# 发送 workflow_end 事件
yield {
"event": "workflow_end",
"data": self._build_final_output(result, elapsed_time)
"data": self._build_final_output(result, elapsed_time, full_content)
}
except Exception as e:
@@ -396,31 +548,6 @@ class WorkflowExecutor:
}
}
@staticmethod
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
@staticmethod
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
@@ -511,178 +638,3 @@ async def execute_workflow_stream(
)
async for event in executor.execute_stream(input_data):
yield event
# ==================== 工具管理系统集成 ====================
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
# """获取工作流可用的工具列表
#
# Args:
# workspace_id: 工作空间ID
# user_id: 用户ID
#
# Returns:
# 可用工具列表
# """
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.warning("工具管理系统不可用")
# return []
#
# try:
# db = next(get_db())
#
# # 创建工具注册表
# registry = ToolRegistry(db)
#
# # 注册内置工具类
# from app.core.tools.builtin import (
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
# )
# registry.register_tool_class(DateTimeTool)
# registry.register_tool_class(JsonTool)
# registry.register_tool_class(BaiduSearchTool)
# registry.register_tool_class(MinerUTool)
# registry.register_tool_class(TextInTool)
#
# # 获取活跃的工具
# import uuid
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
# active_tools = [tool for tool in tools if tool.status.value == "active"]
#
# # 转换为Langchain工具
# langchain_tools = []
# for tool_info in active_tools:
# try:
# tool_instance = registry.get_tool(tool_info.id)
# if tool_instance:
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
# langchain_tools.append(langchain_tool)
# except Exception as e:
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
#
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
# return langchain_tools
#
# except Exception as e:
# logger.error(f"获取工作流工具失败: {e}")
# return []
#
#
# class ToolWorkflowNode:
# """工具工作流节点 - 在工作流中执行工具"""
#
# def __init__(self, node_config: dict, workflow_config: dict):
# """初始化工具节点
#
# Args:
# node_config: 节点配置
# workflow_config: 工作流配置
# """
# self.node_config = node_config
# self.workflow_config = workflow_config
# self.tool_id = node_config.get("tool_id")
# self.tool_parameters = node_config.get("parameters", {})
#
# async def run(self, state: WorkflowState) -> WorkflowState:
# """执行工具节点"""
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.error("工具管理系统不可用")
# state["error"] = "工具管理系统不可用"
# return state
#
# try:
# from sqlalchemy.orm import Session
# db = next(get_db())
#
# # 创建工具执行器
# registry = ToolRegistry(db)
# executor = ToolExecutor(db, registry)
#
# # 准备参数(支持变量替换)
# parameters = self._prepare_parameters(state)
#
# # 执行工具
# result = await executor.execute_tool(
# tool_id=self.tool_id,
# parameters=parameters,
# user_id=uuid.UUID(state["user_id"]),
# workspace_id=uuid.UUID(state["workspace_id"])
# )
#
# # 更新状态
# node_id = self.node_config.get("id")
# if result.success:
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "output": result.data,
# "execution_time": result.execution_time,
# "token_usage": result.token_usage
# }
#
# # 更新运行时变量
# if isinstance(result.data, dict):
# for key, value in result.data.items():
# state["runtime_vars"][f"{node_id}.{key}"] = value
# else:
# state["runtime_vars"][f"{node_id}.result"] = result.data
# else:
# state["error"] = result.error
# state["error_node"] = node_id
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "error": result.error,
# "execution_time": result.execution_time
# }
#
# return state
#
# except Exception as e:
# logger.error(f"工具节点执行失败: {e}")
# state["error"] = str(e)
# state["error_node"] = self.node_config.get("id")
# return state
#
# def _prepare_parameters(self, state: WorkflowState) -> dict:
# """准备工具参数(支持变量替换)"""
# parameters = {}
#
# for key, value in self.tool_parameters.items():
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# # 变量替换
# var_path = value[2:-1]
#
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
# if "." in var_path:
# parts = var_path.split(".")
# current = state.get("variables", {})
#
# for part in parts:
# if isinstance(current, dict) and part in current:
# current = current[part]
# else:
# # 尝试从运行时变量获取
# runtime_key = ".".join(parts)
# current = state.get("runtime_vars", {}).get(runtime_key, value)
# break
#
# parameters[key] = current
# else:
# # 简单变量
# variables = state.get("variables", {})
# parameters[key] = variables.get(var_path, value)
# else:
# parameters[key] = value
#
# return parameters
#
#
# # 注册工具节点到NodeFactory如果存在
# try:
# from app.core.workflow.nodes import NodeFactory
# if hasattr(NodeFactory, 'register_node_type'):
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
# logger.info("工具节点已注册到工作流系统")
# except Exception as e:
# logger.warning(f"注册工具节点失败: {e}")

View File

@@ -1,12 +1,15 @@
import logging
import re
import uuid
from collections import defaultdict
from functools import lru_cache
from typing import Any
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
@@ -15,6 +18,153 @@ from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
logger = logging.getLogger(__name__)
class OutputContent(BaseModel):
"""
Represents a single output segment of an End node.
An output segment can be either:
- literal text (static string)
- a variable placeholder (e.g. {{ node.field }})
Each segment has its own activation state, which is especially
important in stream mode.
"""
literal: str = Field(
...,
description="Raw output content. Can be literal text or a variable placeholder."
)
activate: bool = Field(
...,
description=(
"Whether this output segment is currently active.\n"
"- True: allowed to be emitted/output\n"
"- False: blocked until activated by branch control"
)
)
is_variable: bool = Field(
...,
description=(
"Whether this segment represents a variable placeholder.\n"
"True -> variable (e.g. {{ node.field }})\n"
"False -> literal text"
)
)
def depends_on_node(self, node_id: str) -> bool:
"""
Check if this output segment depends on a specific node's variable.
This method examines the `literal` of the output segment to see if it
contains a variable placeholder referencing the given node in the form:
{{ node_id.field_name }}
It uses a regular expression to match the exact node ID, avoiding
false positives from substring matches (e.g., 'node1' should not match 'node10').
Args:
node_id (str): The ID of the node to check for in this segment's variable placeholders.
Returns:
bool:
- True if the segment contains a variable referencing the given node.
- False otherwise.
Example:
literal = "{{node1.name}}"
depends_on_node("node1") -> True
depends_on_node("node2") -> False
Usage:
This method is primarily used in stream mode to determine whether
a particular variable output segment should be activated when a
specific upstream node completes execution.
"""
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
pattern = re.compile(variable_pattern)
match = pattern.search(self.literal)
if match:
return True
return False
class StreamOutputConfig(BaseModel):
"""
Streaming output configuration for an End node.
This structure controls:
- whether the End node output is globally active
- which upstream branch nodes are responsible for activation
- how each output segment behaves in streaming mode
"""
activate: bool = Field(
...,
description=(
"Global activation state of the End node output.\n"
"If False, no output should be emitted until all control nodes are resolved."
)
)
control_nodes: list[str] = Field(
...,
description=(
"List of upstream branch node IDs that control this End node.\n"
"Each node must signal completion before output becomes active."
)
)
outputs: list[OutputContent] = Field(
...,
description="Ordered list of output segments parsed from the output template."
)
cursor: int = Field(
...,
description=(
"Streaming cursor index.\n"
"Indicates how many output segments have already been emitted."
)
)
def update_activate(self, node_id):
"""
Update activation state based on an upstream node completion.
This method is typically called when a branch/control node finishes execution.
Behavior:
1. If the node is a control node:
- Remove it from `control_nodes`
- If all control nodes are resolved, activate the entire output
2. Activate variable output segments that depend on this node:
- If an output segment is a variable
- And its literal references the completed node_id
- Mark that segment as active
"""
# Case 1: resolve control branch dependency
if node_id in self.control_nodes:
self.control_nodes.remove(node_id)
# All branch constraints resolved → enable output
if not self.control_nodes:
self.activate = True
# Case 2: activate variable segments related to this node
for i in range(len(self.outputs)):
if (
self.outputs[i].is_variable
and self.outputs[i].depends_on_node(node_id)
):
self.outputs[i].activate = True
class GraphBuilder:
def __init__(
self,
@@ -29,6 +179,12 @@ class GraphBuilder:
self.start_node_id = None
self.end_node_ids = []
self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_branch_node = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
self._analyze_end_node_output()
self.graph = StateGraph(WorkflowState)
self.add_nodes()
@@ -43,79 +199,182 @@ class GraphBuilder:
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""
Analyze the prefix configuration for End nodes.
def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID.
This function scans each End node's output template, identifies
references to its direct upstream nodes, and extracts the prefix
string appearing before the first reference.
Args:
node_id (str): The unique identifier of the node.
Returns:
tuple:
- dict[str, str]: Mapping from upstream node ID to its End node prefix
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
str: The type of the node.
Raises:
RuntimeError: If no node with the given `node_id` exists.
"""
import re
try:
return self.node_map[node_id]["type"]
except KeyError:
raise RuntimeError(f"Node not found: Id={node_id}")
prefixes = {}
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
"""Find upstream branch nodes for a given target node in the workflow graph.
# 找到所有 End 节点
This method identifies all upstream control (branch) nodes that can affect
the execution of `target_node`. If `target_node` is reachable from a start
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
and non-branch nodes, recursively traversing upstream through non-branch
nodes. If any non-branch upstream path does not lead to a branch node,
the result will indicate that no valid upstream branch node exists.
Args:
target_node (str): The identifier of the target node.
Returns:
tuple[bool, tuple[str]]:
- has_branch (bool): True if all upstream non-branch paths lead to at least
one branch node; False if any path reaches a start node without a branch.
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
"""
source_nodes = [
edge.get("source")
for edge in self.edges
if edge.get("target") == target_node
]
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return False, tuple()
branch_nodes = []
non_branch_nodes = []
for node_id in source_nodes:
if self.get_node_type(node_id) in BRANCH_NODES:
branch_nodes.append(node_id)
else:
non_branch_nodes.append(node_id)
has_branch = True
for node_id in non_branch_nodes:
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
has_branch = has_branch and node_has_branch
if not has_branch:
break
branch_nodes.extend(nodes)
if not has_branch:
branch_nodes = []
return has_branch, tuple(set(branch_nodes))
def _analyze_end_node_output(self):
"""
Analyze output templates of all End nodes and generate StreamOutputConfig.
This method is responsible for parsing the `output` field of End nodes,
splitting literal text and variable placeholders (e.g. {{ node.field }}),
and determining whether each output segment should be activated immediately
or controlled by upstream branch nodes.
In stream mode:
- If the End node is controlled by any upstream branch node, the output
will be initially inactive and controlled by those branch nodes.
- Otherwise, the output is activated immediately.
In non-stream mode:
- All outputs are activated by default.
"""
# Collect all End nodes in the workflow
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
# Iterate through each End node to analyze its output
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
config = end_node.get("config", {})
output = config.get("output")
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
if not output_template:
# Skip End nodes without output configuration
if not output:
continue
# Find all node references in the template
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
pattern = r'\{\{.*?\}\}|[^{}]+'
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
# Strict variable format: {{ node_id.field_name }}
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
# Identify all direct upstream nodes connected to the End node
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
# Split output into ordered segments
output_template = list(re.findall(pattern, output))
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
# Determine whether each segment is literal text
# True -> literal (can be directly output)
# False -> variable placeholder (needs runtime value)
output_flag = [
not bool(variable_pattern.match(item))
for item in output_template
]
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
# Stream mode: output activation depends on upstream branch nodes
if self.stream:
# Find upstream branch nodes that can control this End node
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
# Build StreamOutputConfig for this End node
self.end_node_map[end_node_id] = StreamOutputConfig(
# If there is no upstream branch, output is active immediately
activate=not has_branch,
logger.info(f"[Prefix Analysis] "
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
# Branch nodes that control activation of this End node
control_nodes=list(control_nodes),
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
# Convert output segments into OutputContent objects
outputs=list(
[
OutputContent(
literal=output_string,
# Literal text can be activated immediately unless blocked by branch
activate=activate,
# Variable segments are marked explicitly
is_variable=not activate
)
for output_string, activate in zip(output_template, output_flag)
]
),
# Cursor for streaming output (initially 0)
cursor=0
)
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
f"activate: {not has_branch}, "
f"control_nodes: {control_nodes},"
f"output: {output_template},"
f"output_activate: {output_flag}")
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"[Prefix Analysis] "
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
# Non-stream mode: all outputs are activated by default
else:
self.end_node_map[end_node_id] = StreamOutputConfig(
activate=True,
control_nodes=[],
outputs=list(
[
OutputContent(
literal=output_string,
activate=True,
is_variable=not activate
)
for output_string, activate in zip(output_template, output_flag)
]
),
cursor=0
)
def add_nodes(self):
"""Add all nodes from the workflow configuration to the state graph.
@@ -135,9 +394,6 @@ class GraphBuilder:
Returns:
None
"""
# Analyze End node prefixes if in stream mode
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
@@ -171,17 +427,6 @@ class GraphBuilder:
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# Inject End node prefix configuration if in stream mode
if self.stream and node_id in end_prefixes:
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"Injected End prefix for node {node_id}")
# Mark nodes as adjacent and referenced to End node in stream mode
if self.stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
# Wrap node's run method to avoid closure issues
if self.stream:
# Stream mode: create an async generator function
@@ -261,6 +506,7 @@ class GraphBuilder:
for source_node, branches in conditional_edges.items():
def make_router(src, branch_list):
"""reate a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets):
def node(s):
# NOTE: NOP NODE MUST NOT MODIFY STATE

View File

@@ -67,10 +67,6 @@ class WorkflowState(TypedDict):
error: str | None
error_node: str | None
# Streaming buffer (stores real-time streaming output of nodes)
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# node activate status
activate: Annotated[dict[str, bool], merge_activate_state]
@@ -300,7 +296,7 @@ class BaseNode(ABC):
"""
if not self.check_activate(state):
yield self.trans_activate(state)
logger.info(f"跳过节点{self.node_id}")
logger.info(f"jump node: {self.node_id}")
return
import time
@@ -313,19 +309,6 @@ class BaseNode(ABC):
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()
# Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Check if this node is adjacent to End node (for message type)
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping)
chunks = []
final_result = None
@@ -340,66 +323,25 @@ class BaseNode(ABC):
raise TimeoutError()
# Check if it's a completion marker
if isinstance(item, dict) and item.get("__final__"):
if item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# String is a chunk
else:
chunk_count += 1
chunks.append(item)
full_content = "".join(chunks)
content = str(item.get("chunk"))
done = item.get("done", False)
chunks.append(content)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
"type": chunk_type, # "message" or "node_chunk"
"type": "node_chunk",
"node_id": self.node_id,
"chunk": item,
"full_content": full_content,
"chunk_index": chunk_count
"chunk": content,
"done": done
})
# 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
else:
# Other types are also treated as chunks
chunk_count += 1
chunk_str = str(item)
chunks.append(chunk_str)
full_content = "".join(chunks)
# Send chunks for all nodes
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,
"chunk_index": chunk_count
})
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
elapsed_time = time.time() - start_time
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
@@ -426,16 +368,6 @@ class BaseNode(ABC):
"looping": state["looping"]
}
# Add streaming buffer for non-End nodes
if not is_end_node:
state_update["streaming_buffer"] = {
self.node_id: {
"full_content": "".join(chunks),
"chunk_count": chunk_count,
"is_complete": True # Mark as complete
}
}
# Finally yield state update
# LangGraph will merge this into state
yield state_update | self.trans_activate(state)
@@ -544,6 +476,11 @@ class BaseNode(ABC):
"error_node": self.node_id
}
else:
writer = get_stream_writer()
writer({
"type": "node_error",
**node_output
})
# 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")

View File

@@ -0,0 +1,3 @@
from app.core.workflow.nodes.code.node import CodeNode
__all__ = ["CodeNode"]

View File

@@ -0,0 +1,50 @@
from typing import Literal
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
class InputVariable(BaseModel):
name: str = Field(
...,
description="variable name"
)
variable: str = Field(
...,
description="variable selector"
)
class OutputVariable(BaseModel):
name: str = Field(
...,
description="variable name"
)
type: VariableType = Field(
...,
description="variable selector"
)
class CodeNodeConfig(BaseNodeConfig):
input_variables: list[InputVariable] = Field(
default_factory=list,
description="input variables"
)
output_variables: list[OutputVariable] = Field(
default_factory=list,
description="output variables"
)
code: str = Field(
default="",
description="code content"
)
language: Literal['python3', 'nodejs'] = Field(
...,
description="language"
)

View File

@@ -0,0 +1,121 @@
import base64
import json
import logging
import re
from string import Template
from textwrap import dedent
from typing import Any
import httpx
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.code.config import CodeNodeConfig
logger = logging.getLogger(__name__)
SCRIPT_TEMPLATE = Template(dedent("""
$code
import json
from base64 import b64decode
# decode and prepare input dict
inputs_obj = json.loads(b64decode('$inputs_variable').decode('utf-8'))
# execute main function
output_obj = main(**inputs_obj)
# convert output to json and print
output_json = json.dumps(output_obj, indent=4)
result = "<<RESULT>>" + output_json + "<<RESULT>>"
print(result)
"""))
class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: CodeNodeConfig | None = None
def extract_result(self, content: str):
match = re.search(r'<<RESULT>>(.*?)<<RESULT>>', content, re.DOTALL)
if match:
extracted = match.group(1)
exec_result = json.loads(extracted)
result = {}
for output in self.typed_config.output_variables:
value = exec_result.get(output.name)
if value is None:
raise RuntimeError(f"Return value {output.name} does not exist")
match output.type:
case VariableType.STRING:
if not isinstance(value, str):
raise RuntimeError(f"Return value {output.name} should be a string")
case VariableType.BOOLEAN:
if not isinstance(value, bool):
raise RuntimeError(f"Return value {output.name} should be a boolean")
case VariableType.NUMBER:
if not isinstance(value, (int, float)):
raise RuntimeError(f"Return value {output.name} should be a number")
case VariableType.OBJECT:
if not isinstance(value, dict):
raise RuntimeError(f"Return value {output.name} should be a dictionary")
case VariableType.ARRAY_STRING:
if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of strings")
case VariableType.ARRAY_NUMBER:
if not isinstance(value, list) or not all(isinstance(v, (int, float)) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of numbers")
case VariableType.ARRAY_OBJECT:
if not isinstance(value, list) or not all(isinstance(v, dict) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of dictionaries")
case VariableType.ARRAY_BOOLEAN:
if not isinstance(value, list) or not all(isinstance(v, bool) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of booleans")
result[output.name] = value
return result
else:
raise RuntimeError("The output of main must be a dictionary")
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = CodeNodeConfig(**self.config)
input_variable_dict = {}
for input_variable in self.typed_config.input_variables:
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
code = base64.b64decode(
self.typed_config.code
).decode("utf-8")
input_variable_dict = base64.b64encode(
json.dumps(input_variable_dict).encode("utf-8")
).decode("utf-8")
final_script = SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
async with httpx.AsyncClient() as client:
response = await client.post(
"http://sandbox:8194/v1/sandbox/run",
headers={
"x-api-key": 'redbear-sandbox'
},
json={
"language": self.typed_config.language,
"code": base64.b64encode(final_script.encode("utf-8")).decode("utf-8"),
"options": {
"enable_network": True
}
}
)
resp = response.json()
match resp['code']:
case 31:
raise RuntimeError("Operation not permitted")
case 0:
return self.extract_result(resp["data"]["stdout"])
case _:
raise Exception(resp["message"])

View File

@@ -10,21 +10,22 @@ from app.core.workflow.nodes.base_config import (
VariableDefinition,
VariableType,
)
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
@@ -49,5 +50,6 @@ __all__ = [
"QuestionClassifierNodeConfig",
"ToolNodeConfig",
"MemoryReadNodeConfig",
"MemoryWriteNodeConfig"
"MemoryWriteNodeConfig",
"CodeNodeConfig"
]

View File

@@ -1,5 +1,4 @@
import asyncio
import copy
import logging
import re
from typing import Any

View File

@@ -6,7 +6,6 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType

View File

@@ -5,10 +5,8 @@ End 节点实现
"""
import logging
import re
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
@@ -37,24 +35,8 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
else:
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
])
output = "工作流已完成"
output = ""
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
@@ -63,274 +45,3 @@ class EndNode(BaseNode):
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output
def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args:
template: 模板字符串
Returns:
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀'
返回:[
{"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"}
]
Args:
template: 模板字符串
state: 工作流状态
Returns:
模板部分列表
"""
import re
parts = []
last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template):
start, end = match.span()
# 添加前面的静态文本
if start > last_end:
static_text = template[last_end:start]
if static_text:
parts.append({"type": "static", "content": static_text})
# 解析动态引用
ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref:
node_id, field = ref.split('.', 1)
parts.append({
"type": "dynamic",
"node_id": node_id,
"field": field,
"raw": ref
})
else:
# 其他引用(如 {{var.xxx}}),当作静态处理
# 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered})
last_end = end
# 添加最后的静态文本
if last_end < len(template):
static_text = template[last_end:]
if static_text:
parts.append({"type": "static", "content": static_text})
return parts
async def execute_stream(self, state: WorkflowState):
"""Execute End node business logic (streaming)
Smart output strategy:
1. Check if template references a direct upstream LLM node
2. If yes, only output the part AFTER that reference (suffix)
3. Prefix and LLM content have already been sent during LLM node streaming
Note: Only LLM nodes get this special treatment. Other node types output normally.
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- Direct upstream LLM node is llm_qa
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
- LLM content was streamed during LLM node execution
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
Args:
state: Workflow state
Yields:
Completion marker
"""
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板
output_template = self.config.get("output")
if not output_template:
output = "工作流已完成"
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": "",
"full_content": output,
"chunk_index": 1,
"is_suffix": False
})
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
}
])
yield {"__final__": True, "result": output}
return
# Find direct upstream LLM nodes
direct_upstream_llm_nodes = []
for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id:
source_node_id = edge.get("source")
# Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
break
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
# Parse template parts
parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}")
# Find the first reference to a direct upstream LLM node
upstream_llm_ref_index = None
for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
upstream_llm_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
break
if upstream_llm_ref_index is None:
# No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state, strict=False)
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# Send complete content via writer (as a single message chunk)
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": output,
"full_content": output,
"chunk_index": 1,
"is_suffix": False
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
# yield completion marker
yield {"__final__": True, "result": output}
return
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# Collect suffix parts
suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_llm_ref_index + 1, len(parts)):
part = parts[i]
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
if part["type"] == "static":
# 静态文本
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
suffix_parts.append(part["content"])
elif part["type"] == "dynamic":
# Other dynamic references (if there are multiple references)
node_id = part["node_id"]
field = part["field"]
# Use VariablePool to get variable value
pool = self.get_variable_pool(state)
try:
# Try to get variable value with default empty string
content = pool.get([node_id, field], default="")
logger.info(f"[后缀调试] 获取变量 {node_id}.{field} 成功: '{content}'")
except Exception as e:
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
content = ""
# Convert to string if not None
suffix_parts.append(str(content) if content is not None else "")
# 拼接后缀
suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": full_output
}
])
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
if suffix:
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
# 一次性输出后缀(作为单个 chunk
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
# 而是通过 writer 直接发送
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End 节点的输出使用 message 类型
"node_id": self.node_id,
"chunk": suffix,
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
"chunk_index": 1,
"is_suffix": True
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
# yield 完成标记(包含完整输出)
yield {"__final__": True, "result": full_output}

View File

@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: IfElseNodeConfig | None= None
self.typed_config: IfElseNodeConfig | None = None
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:

View File

@@ -7,18 +7,18 @@ LLM 节点实现
import logging
import re
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from langchain_core.messages import AIMessage
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
@@ -231,42 +231,14 @@ class LLMNode(BaseNode):
文本片段chunk或完成标记
"""
self.typed_config = LLMNodeConfig(**self.config)
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,
"chunk_index": 0,
"is_prefix": True # 标记这是前缀
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
@@ -284,12 +256,19 @@ class LLMNode(BaseNode):
# 只有当内容不为空时才处理
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# 流式返回每个文本片段
yield content
yield {
"__final__": False,
"chunk": content
}
yield {
"__final__": False,
"chunk": "",
"done": True
}
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据

View File

@@ -1,7 +1,6 @@
import uuid
from uuid import UUID
from pydantic import Field
from typing import Literal
from app.core.workflow.nodes.base_config import BaseNodeConfig
@@ -11,7 +10,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
...
)
config_id: int = Field(
config_id: UUID | int = Field(
...
)
@@ -26,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
...
)
config_id: int = Field(
config_id: UUID = Field(
...
)

Some files were not shown because too many files have changed in this diff Show More