Compare commits
383 Commits
release/v0
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a7eaf563d7 | ||
|
|
4c7809ce4a | ||
|
|
51847955cd | ||
|
|
52260f469a | ||
|
|
c566d22836 | ||
|
|
62c557deae | ||
|
|
677a603835 | ||
|
|
447d8790ad | ||
|
|
7b3bf41120 | ||
|
|
fe3c31c08c | ||
|
|
4e7ab3d7e3 | ||
|
|
47b25d7a26 | ||
|
|
aca7d25001 | ||
|
|
dfa7a2d4cf | ||
|
|
169e01276d | ||
|
|
07e698265e | ||
|
|
46ed7e38bf | ||
|
|
3364374dc6 | ||
|
|
f125d11b6d | ||
|
|
657d48a5f9 | ||
|
|
3735bdde19 | ||
|
|
3f906d81cb | ||
|
|
7c1f622797 | ||
|
|
c96df6bfa5 | ||
|
|
9e6e8f50f8 | ||
|
|
88b89ef315 | ||
|
|
cc1528f550 | ||
|
|
1c8a83140b | ||
|
|
34276e2066 | ||
|
|
918e7285c4 | ||
|
|
056d422c71 | ||
|
|
5ee54f4e0e | ||
|
|
260c75e70c | ||
|
|
2d7401922f | ||
|
|
8c7a1348cf | ||
|
|
24fbdbd716 | ||
|
|
aad8f0e36b | ||
|
|
15cad44f08 | ||
|
|
0271454671 | ||
|
|
d0ddf288ca | ||
|
|
bc250ac377 | ||
|
|
7922fc3b0e | ||
|
|
514c19a247 | ||
|
|
41550d4a41 | ||
|
|
33cc3c1c3f | ||
|
|
8f0a1d9c6e | ||
|
|
72b5e5cf8e | ||
|
|
62aba2dd38 | ||
|
|
cdd6b80089 | ||
|
|
333836f5e7 | ||
|
|
2d28b4b05c | ||
|
|
48aca996ff | ||
|
|
c8c7e9b304 | ||
|
|
97ff023995 | ||
|
|
34f0c3b90c | ||
|
|
7c2902d2b8 | ||
|
|
8e41afdffc | ||
|
|
7268886294 | ||
|
|
cbae900866 | ||
|
|
ffff138a6f | ||
|
|
88c95db8d0 | ||
|
|
bc36b79105 | ||
|
|
5694bc0230 | ||
|
|
ead3080b2b | ||
|
|
21eae29bb7 | ||
|
|
406740b524 | ||
|
|
9d30bc4062 | ||
|
|
fad91b64ab | ||
|
|
2132e71a81 | ||
|
|
24dafa7359 | ||
|
|
9a3c74fb64 | ||
|
|
f571f0688a | ||
|
|
3efb3e8a35 | ||
|
|
cfcb278406 | ||
|
|
dc0d34c281 | ||
|
|
72076c218f | ||
|
|
151fd3b950 | ||
|
|
f27de7df35 | ||
|
|
9f2b6390b0 | ||
|
|
e196f86e30 | ||
|
|
ec41d45234 | ||
|
|
567d1ba18b | ||
|
|
5e47fc45ab | ||
|
|
b471d56a86 | ||
|
|
61f8029205 | ||
|
|
1aff4eda67 | ||
|
|
5d5351f0bc | ||
|
|
1224802ac6 | ||
|
|
3464573f17 | ||
|
|
9cf49c9c75 | ||
|
|
4e837cb90c | ||
|
|
e4fb58496b | ||
|
|
15a254c0cd | ||
|
|
d62746fc8c | ||
|
|
4b8b6fe407 | ||
|
|
6754834eb3 | ||
|
|
be98db561d | ||
|
|
574d0afc72 | ||
|
|
31c8ad611c | ||
|
|
b23730388d | ||
|
|
1b853aa893 | ||
|
|
36cb0a12ad | ||
|
|
5439eacf2d | ||
|
|
2687c3b80e | ||
|
|
fa009327ad | ||
|
|
838bd46e83 | ||
|
|
ccc2009aa8 | ||
|
|
d9aba92314 | ||
|
|
696b0475a8 | ||
|
|
e7370489e8 | ||
|
|
f1503b2238 | ||
|
|
cd4661e878 | ||
|
|
364e01ec7a | ||
|
|
ffb7b0ba38 | ||
|
|
22151eb49b | ||
|
|
d0354345f6 | ||
|
|
b1e61eb1e4 | ||
|
|
36e0ed15b6 | ||
|
|
095dfc2879 | ||
|
|
17dea9433e | ||
|
|
c285444e2f | ||
|
|
8ba402d080 | ||
|
|
88ab86734d | ||
|
|
504d87b0b0 | ||
|
|
b0d5818351 | ||
|
|
8826a01d32 | ||
|
|
a651ae6ed4 | ||
|
|
ee50b25d06 | ||
|
|
a67be85858 | ||
|
|
59c5a3973a | ||
|
|
d76d7343ff | ||
|
|
2b9638e7d3 | ||
|
|
3459a73705 | ||
|
|
bd480a466b | ||
|
|
4c34cb55b6 | ||
|
|
e137e4a38a | ||
|
|
b5989bbc25 | ||
|
|
c31ff7ceef | ||
|
|
9206c7642a | ||
|
|
d1b4f2b6c2 | ||
|
|
75066f2827 | ||
|
|
303f3aefef | ||
|
|
44fb5e0fd5 | ||
|
|
17a695120a | ||
|
|
6dc716eaf8 | ||
|
|
194be086d4 | ||
|
|
cca3900678 | ||
|
|
4fe32b7dbc | ||
|
|
c49603c25b | ||
|
|
8de85a4041 | ||
|
|
58a2135fa4 | ||
|
|
ab9a97db22 | ||
|
|
d291c241d5 | ||
|
|
24d4cb9b94 | ||
|
|
5b9adb799f | ||
|
|
38b41df36b | ||
|
|
34a9befe5c | ||
|
|
67fd579074 | ||
|
|
e2714b942d | ||
|
|
6b2556f870 | ||
|
|
43e6e9d201 | ||
|
|
131e0cc4c7 | ||
|
|
537be81b8f | ||
|
|
765168db7f | ||
|
|
1e16b06a24 | ||
|
|
cd4c93a5cb | ||
|
|
808961243d | ||
|
|
4d80e119f7 | ||
|
|
10c87edae1 | ||
|
|
0eb335d112 | ||
|
|
b8b26ccfe5 | ||
|
|
e89c23da4d | ||
|
|
f3da8956d9 | ||
|
|
b1147d77af | ||
|
|
66bc2fb41f | ||
|
|
4e538a6df8 | ||
|
|
ced087f8ae | ||
|
|
0f1eed0b1e | ||
|
|
95f15b77a3 | ||
|
|
f9ccfd5ca0 | ||
|
|
7207d7c847 | ||
|
|
00c4a524b7 | ||
|
|
9c3e0b5541 | ||
|
|
33bfe33eb3 | ||
|
|
3127c382a4 | ||
|
|
1748a390ec | ||
|
|
a7c0837049 | ||
|
|
44bf1eeae2 | ||
|
|
762b7a8ef1 | ||
|
|
102712a16e | ||
|
|
40810c59d7 | ||
|
|
35a10e86b5 | ||
|
|
c0c985494d | ||
|
|
8984ba7aef | ||
|
|
179869d481 | ||
|
|
5f29956f2b | ||
|
|
7e56c09620 | ||
|
|
dbc4ba84c2 | ||
|
|
9e4a527675 | ||
|
|
2e7f6afe3f | ||
|
|
45833542a7 | ||
|
|
1be6de30d7 | ||
|
|
981d78c8ba | ||
|
|
fbc7bedb6c | ||
|
|
9a4b1f0937 | ||
|
|
4786b0c5d4 | ||
|
|
17bed26096 | ||
|
|
511e16f1d3 | ||
|
|
18204bc1f7 | ||
|
|
e5e914903c | ||
|
|
7ba443afa5 | ||
|
|
b58d97fad3 | ||
|
|
d2a67a53b5 | ||
|
|
c0b556000c | ||
|
|
462c3b0696 | ||
|
|
d34ad73439 | ||
|
|
2c21712d58 | ||
|
|
2862db3534 | ||
|
|
bf3e30dac0 | ||
|
|
ce01e588c9 | ||
|
|
2a23082203 | ||
|
|
d373f924f6 | ||
|
|
eaf46ee006 | ||
|
|
d51355a0ad | ||
|
|
1e481a311a | ||
|
|
375660f232 | ||
|
|
46abb23ee8 | ||
|
|
8555bb697c | ||
|
|
f821893653 | ||
|
|
f6031baee4 | ||
|
|
75b3ea1f05 | ||
|
|
c818ba7bc7 | ||
|
|
74f0018962 | ||
|
|
3a0f07d36f | ||
|
|
8fb9e779a6 | ||
|
|
c5a794f1b5 | ||
|
|
3aa2cdd754 | ||
|
|
d93d52cf10 | ||
|
|
2abbd5a7fb | ||
|
|
2a10e9f7ee | ||
|
|
166d05afe9 | ||
|
|
2eff8d1962 | ||
|
|
93c9e76c4b | ||
|
|
021cb09b82 | ||
|
|
28e6939884 | ||
|
|
8847039d76 | ||
|
|
a047cf2e91 | ||
|
|
a8ae16e321 | ||
|
|
2694576a32 | ||
|
|
e4f10670f6 | ||
|
|
1324ba3a49 | ||
|
|
73c7810310 | ||
|
|
d160076267 | ||
|
|
a53be31765 | ||
|
|
ed8c1c7c19 | ||
|
|
159c8d1ff9 | ||
|
|
8932d455d8 | ||
|
|
3af183f6c3 | ||
|
|
4475be51cc | ||
|
|
c3ea3b751b | ||
|
|
e2c67d0c5b | ||
|
|
87731090ca | ||
|
|
80ca247435 | ||
|
|
a5b8d3afa5 | ||
|
|
1f615a06ad | ||
|
|
4123560a98 | ||
|
|
5267bd60a5 | ||
|
|
f76bffb482 | ||
|
|
51185c83c9 | ||
|
|
f1f887faae | ||
|
|
d53cbe7868 | ||
|
|
722746c78b | ||
|
|
46f0f3cee9 | ||
|
|
e1f5607836 | ||
|
|
ebc41b2eec | ||
|
|
7cd0d78424 | ||
|
|
d740559749 | ||
|
|
399357f752 | ||
|
|
3b4b474ce8 | ||
|
|
4534e46811 | ||
|
|
7bfa7b3f02 | ||
|
|
1cc34d8e62 | ||
|
|
2eff6b2e9d | ||
|
|
b046411302 | ||
|
|
6ab65b3626 | ||
|
|
cf321f9b09 | ||
|
|
8228d38859 | ||
|
|
c2e3110fa2 | ||
|
|
85681db7b7 | ||
|
|
1fc04c37d3 | ||
|
|
0fd8a122fb | ||
|
|
e3b6ede992 | ||
|
|
3601737869 | ||
|
|
9de6b4f151 | ||
|
|
4f4f55d67f | ||
|
|
714c624dc6 | ||
|
|
94cced8323 | ||
|
|
9b8ed16e37 | ||
|
|
a5e44cd229 | ||
|
|
eccc208229 | ||
|
|
79cfabb45d | ||
|
|
af6e1e2b99 | ||
|
|
4ad51c1b24 | ||
|
|
1919580759 | ||
|
|
b27ffe57e6 | ||
|
|
c115bcde54 | ||
|
|
c44712167f | ||
|
|
1aabaff1f2 | ||
|
|
21c0383efb | ||
|
|
313f19eba4 | ||
|
|
c6bcf53fea | ||
|
|
86812b34d1 | ||
|
|
15f9c49418 | ||
|
|
6e18c92a13 | ||
|
|
7870c6c33f | ||
|
|
ebe018347b | ||
|
|
86fe6fe5ab | ||
|
|
9e828b1750 | ||
|
|
45adb9627a | ||
|
|
940d3d4567 | ||
|
|
6bd7b2b8bb | ||
|
|
f2d6fd7b08 | ||
|
|
7219274d94 | ||
|
|
b84c82880c | ||
|
|
fcc418b4a0 | ||
|
|
15c0bb4c9e | ||
|
|
8db4f914d8 | ||
|
|
f3f9211c9c | ||
|
|
51680b7077 | ||
|
|
a2a69840f7 | ||
|
|
3a4a7590c2 | ||
|
|
bcc8b7ce3c | ||
|
|
1c7fe6d134 | ||
|
|
c4039f52bd | ||
|
|
bd851d5e86 | ||
|
|
00e448c5d6 | ||
|
|
4aeec8afbf | ||
|
|
f10432bf3f | ||
|
|
f0efed8aa1 | ||
|
|
4a4931bee2 | ||
|
|
afcf12ebc9 | ||
|
|
8f86d3417d | ||
|
|
92dfc54c4c | ||
|
|
c93bcb8678 | ||
|
|
98b2da9123 | ||
|
|
cd5f1a1b28 | ||
|
|
0e2e495d09 | ||
|
|
84c6c7e2a6 | ||
|
|
c8ebf9c75a | ||
|
|
29852ff0a5 | ||
|
|
f06ca62589 | ||
|
|
3f39a2be12 | ||
|
|
575190a96d | ||
|
|
78559d98eb | ||
|
|
398964c747 | ||
|
|
a634565296 | ||
|
|
a5ecbec9a6 | ||
|
|
fe79978f88 | ||
|
|
978ec8bc75 | ||
|
|
6e77f5b068 | ||
|
|
c9dbb64269 | ||
|
|
546d32e3eb | ||
|
|
616f6401b4 | ||
|
|
d047190453 | ||
|
|
17504b1b9c | ||
|
|
5a0d3df689 | ||
|
|
871304c89b | ||
|
|
8155150e45 | ||
|
|
d9fb8edaa9 | ||
|
|
dda61679bd | ||
|
|
6ac10a8297 | ||
|
|
0695c11739 | ||
|
|
7a4297c4f1 | ||
|
|
2c9e5df27d | ||
|
|
6db37d35ed | ||
|
|
ceee4fe5cf | ||
|
|
130b4a57de | ||
|
|
1cee27e830 | ||
|
|
ba2ff053f9 | ||
|
|
227665439f | ||
|
|
1a2e043ec2 | ||
|
|
89500df0ac | ||
|
|
cb4e80f1bc |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -35,3 +35,6 @@ nltk_data/
|
||||
tika-server*.jar*
|
||||
cl100k_base.tiktoken
|
||||
libssl*.deb
|
||||
|
||||
sandbox/lib/seccomp_python/target
|
||||
sandbox/lib/seccomp_nodejs/target
|
||||
|
||||
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
@@ -3,9 +3,14 @@ import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.config import settings
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
@@ -63,15 +68,20 @@ celery_app.conf.update(
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -79,40 +89,40 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"regenerate-memory-cache": {
|
||||
"task": "app.tasks.regenerate_memory_cache",
|
||||
"schedule": memory_cache_regeneration_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"run-forgetting-cycle": {
|
||||
"task": "app.tasks.run_forgetting_cycle_task",
|
||||
"schedule": forgetting_cycle_schedule,
|
||||
"kwargs": {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
}
|
||||
# beat_schedule_config = {
|
||||
# "run-workspace-reflection": {
|
||||
# "task": "app.tasks.workspace_reflection_task",
|
||||
# "schedule": workspace_reflection_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
# "regenerate-memory-cache": {
|
||||
# "task": "app.tasks.regenerate_memory_cache",
|
||||
# "schedule": memory_cache_regeneration_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
# "run-forgetting-cycle": {
|
||||
# "task": "app.tasks.run_forgetting_cycle_task",
|
||||
# "schedule": forgetting_cycle_schedule,
|
||||
# "kwargs": {
|
||||
# "config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
# if settings.DEFAULT_WORKSPACE_ID:
|
||||
# beat_schedule_config["write-total-memory"] = {
|
||||
# "task": "app.controllers.memory_storage_controller.search_all",
|
||||
# "schedule": memory_increment_schedule,
|
||||
# "kwargs": {
|
||||
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
# },
|
||||
# }
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
# celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -45,6 +45,7 @@ from . import (
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
ontology_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -90,5 +91,6 @@ manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -872,3 +872,44 @@ async def update_workflow_config(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
app_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
- daily_new_users: 每日新增用户数
|
||||
- total_new_users: 总新增用户数
|
||||
- daily_api_calls: 每日API调用次数
|
||||
- total_api_calls: 总API调用次数
|
||||
- daily_tokens: 每日token消耗
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
@@ -7,11 +7,13 @@ Routes:
|
||||
GET /memory/config/emotion - 获取情绪引擎配置
|
||||
POST /memory/config/emotion - 更新情绪引擎配置
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -20,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -32,11 +35,11 @@ router = APIRouter(
|
||||
|
||||
class EmotionConfigQuery(BaseModel):
|
||||
"""情绪配置查询请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: UUID = Field(..., description="配置ID")
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
|
||||
"""情绪配置更新请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||
@@ -45,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: int = Query(..., description="配置ID"),
|
||||
config_id: UUID|int = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -78,7 +81,7 @@ def get_emotion_config(
|
||||
f"用户 {current_user.username} 请求获取情绪配置",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
@@ -157,6 +160,7 @@ def update_emotion_config(
|
||||
}
|
||||
}
|
||||
"""
|
||||
config.config_id=resolve_config_id(config.config_id, db)
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求更新情绪配置",
|
||||
|
||||
@@ -53,7 +53,7 @@ async def get_emotion_tags(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
@@ -63,7 +63,7 @@ async def get_emotion_tags(
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
@@ -73,7 +73,7 @@ async def get_emotion_tags(
|
||||
api_logger.info(
|
||||
"情绪标签统计获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_count": data.get("total_count", 0),
|
||||
"tags_count": len(data.get("tags", []))
|
||||
}
|
||||
@@ -84,7 +84,7 @@ async def get_emotion_tags(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪标签统计失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"limit": request.limit
|
||||
}
|
||||
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_wordcloud(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
limit=request.limit
|
||||
)
|
||||
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
|
||||
api_logger.info(
|
||||
"情绪词云数据获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_keywords": data.get("total_keywords", 0)
|
||||
}
|
||||
)
|
||||
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪词云数据失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -159,21 +159,21 @@ async def get_emotion_health(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"time_range": request.time_range
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.calculate_emotion_health_index(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
time_range=request.time_range
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
@@ -186,7 +186,7 @@ async def get_emotion_health(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪健康指数失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
api_logger.info(
|
||||
f"用户 {request.group_id} 的建议缓存不存在或已过期",
|
||||
extra={"group_id": request.group_id}
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
|
||||
@@ -310,7 +310,7 @@ async def get_file_url(
|
||||
try:
|
||||
if permanent:
|
||||
# Generate permanent URL (no expiration check)
|
||||
server_url = f"http://{settings.SERVER_IP}:8000/api"
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
url = f"{server_url}/storage/permanent/{file_id}"
|
||||
return success(
|
||||
data={
|
||||
|
||||
@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||
@@ -137,7 +137,7 @@ async def get_preference_tags(
|
||||
Get user preference tags from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||
tag_category: Optional category filter
|
||||
start_date: Optional start date filter
|
||||
@@ -146,20 +146,20 @@ async def get_preference_tags(
|
||||
Returns:
|
||||
List of preference tags from cache
|
||||
"""
|
||||
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -192,17 +192,17 @@ async def get_preference_tags(
|
||||
|
||||
filtered_preferences.append(pref)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_dimension_portrait(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_history: bool = Query(False, description="Include historical trends"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
|
||||
Get user's four-dimension personality portrait from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_history: Whether to include historical trend data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Four-dimension personality portrait from cache
|
||||
"""
|
||||
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
||||
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_interest_area_distribution(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
|
||||
Get user's interest area distribution from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_trends: Whether to include trend analysis data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Interest area distribution from cache
|
||||
"""
|
||||
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_behavior_habits(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||
@@ -309,7 +309,7 @@ async def get_behavior_habits(
|
||||
Get user's behavioral habits from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_level: Filter by confidence level (high, medium, low)
|
||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||
time_period: Filter by time period (current, past)
|
||||
@@ -317,20 +317,20 @@ async def get_behavior_habits(
|
||||
Returns:
|
||||
List of behavioral habits from cache
|
||||
"""
|
||||
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -368,11 +368,11 @@ async def get_behavior_habits(
|
||||
|
||||
filtered_habits.append(habit)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ async def write_server(
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
@@ -160,19 +160,18 @@ async def write_server(
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
messages_list, # 传递结构化消息列表
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -196,7 +195,7 @@ async def write_server_async(
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
@@ -226,10 +225,10 @@ async def write_server_async(
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
@@ -255,16 +254,14 @@ async def read_server(
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
user_input: Read request with message, history, search_switch, and group_id
|
||||
user_input: Read request with message, history, search_switch, and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
@@ -279,12 +276,13 @@ async def read_server(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.group_id,
|
||||
user_input.end_user_id,
|
||||
user_input.message,
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
@@ -295,17 +293,20 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
end_user_id=user_input.end_user_id,
|
||||
retrieve_info=retrieve_info,
|
||||
history=history,
|
||||
query=query,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
if "信息不足,无法回答" in result['answer']:
|
||||
result['answer']=retrieve_info
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -403,7 +404,7 @@ async def read_server_async(
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Read task queued: {task.id}")
|
||||
@@ -447,7 +448,7 @@ async def get_read_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -524,7 +525,7 @@ async def get_write_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -578,16 +579,16 @@ async def status_type(
|
||||
Determine the type of user message (read or write)
|
||||
|
||||
Args:
|
||||
user_input: Request containing user message and group_id
|
||||
user_input: Request containing user message and end_user_id
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
# 将消息列表转换为字符串用于分类
|
||||
# 只取最后一条用户消息进行分类
|
||||
last_user_message = ""
|
||||
@@ -595,11 +596,11 @@ async def status_type(
|
||||
if msg.get('role') == 'user':
|
||||
last_user_message = msg.get('content', '')
|
||||
break
|
||||
|
||||
|
||||
if not last_user_message:
|
||||
# 如果没有用户消息,使用所有消息的内容
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
last_user_message,
|
||||
user_input.config_id,
|
||||
@@ -624,7 +625,7 @@ async def get_knowledge_type_stats_api(
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
@@ -697,7 +698,7 @@ async def get_user_profile_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取工作空间下Popular Memory Tags,包含:
|
||||
获取用户详情,包含:
|
||||
- name: 用户名字(直接使用 end_user_id)
|
||||
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
||||
- hot_tags: 4个热门记忆标签
|
||||
|
||||
@@ -49,63 +49,134 @@ async def get_workspace_end_users(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同,
|
||||
并包含每个用户的记忆配置信息(memory_config_id 和 memory_config_name)
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
memory_configs_map = {}
|
||||
if end_user_ids:
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
try:
|
||||
return await asyncio.to_thread(
|
||||
get_end_users_connected_configs_batch,
|
||||
end_user_ids, db
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
# 失败时使用空字典,不影响其他数据返回
|
||||
return {}
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
# RAG 模式:批量查询
|
||||
try:
|
||||
chunk_map = await asyncio.to_thread(
|
||||
memory_dashboard_service.get_users_total_chunk_batch,
|
||||
end_user_ids, db, current_user
|
||||
)
|
||||
return {uid: {"total": count} for uid, count in chunk_map.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
memory_num = {}
|
||||
if current_workspace_type == "neo4j":
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
elif current_workspace_type == "rag":
|
||||
memory_num = {
|
||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||
}
|
||||
|
||||
# 从批量查询结果中获取配置信息
|
||||
user_id = str(end_user.id)
|
||||
memory_config_info = memory_configs_map.get(user_id, {
|
||||
"memory_config_id": None,
|
||||
"memory_config_name": None
|
||||
})
|
||||
|
||||
# 只保留需要的字段,移除 error 字段(如果有)
|
||||
memory_config = {
|
||||
"memory_config_id": memory_config_info.get("memory_config_id"),
|
||||
"memory_config_name": memory_config_info.get("memory_config_name")
|
||||
}
|
||||
|
||||
result.append(
|
||||
{
|
||||
'end_user': end_user,
|
||||
'memory_num': memory_num,
|
||||
'memory_config': memory_config
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
},
|
||||
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
|
||||
'memory_config': {
|
||||
"memory_config_id": config_info.get("memory_config_id"),
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
)
|
||||
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -33,7 +34,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -83,7 +84,8 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
config_id = resolve_config_id((config_id), db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
@@ -106,7 +108,7 @@ async def trigger_forgetting_cycle(
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
group_id=end_user_id, # 服务层方法的参数名是 group_id
|
||||
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=config_id
|
||||
@@ -128,7 +130,7 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: int,
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -157,6 +159,7 @@ async def read_forgetting_config(
|
||||
)
|
||||
|
||||
try:
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
@@ -194,6 +197,8 @@ async def update_forgetting_config(
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id=resolve_config_id((payload.config_id), db)
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
@@ -236,7 +241,7 @@ async def update_forgetting_config(
|
||||
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -246,7 +251,7 @@ async def get_forgetting_stats(
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
group_id: 组ID(即 end_user_id,可选)
|
||||
end_user_id: 组ID(即 end_user_id,可选)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
@@ -254,26 +259,25 @@ async def get_forgetting_stats(
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 group_id,通过它获取 config_id
|
||||
# 如果提供了 end_user_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
@@ -283,14 +287,14 @@ async def get_forgetting_stats(
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"group_id={group_id}, config_id={config_id}"
|
||||
f"end_user_id={end_user_id}, config_id={config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取统计信息
|
||||
stats = await forget_service.get_forgetting_stats(
|
||||
db=db,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
@@ -324,7 +328,7 @@ async def get_forgetting_curve(
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
request.config_id = resolve_config_id((request.config_id), db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
|
||||
@@ -27,27 +27,27 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve perceptual memory statistics for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group (usually end_user_id in this context)
|
||||
end_user_id: ID of the user group (usually end_user_id in this context)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Response containing memory count statistics
|
||||
"""
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
count_stats = service.get_memory_count(group_id)
|
||||
count_stats = service.get_memory_count(end_user_id)
|
||||
|
||||
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||
|
||||
@@ -57,37 +57,37 @@ def get_memory_count(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch memory statistics",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
|
||||
def get_last_visual_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent VISION-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest visual memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
visual_memory = service.get_latest_visual_memory(group_id)
|
||||
visual_memory = service.get_latest_visual_memory(end_user_id)
|
||||
|
||||
if visual_memory is None:
|
||||
api_logger.info(f"No visual memory found: group_id={group_id}")
|
||||
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No visual memory available"
|
||||
@@ -101,37 +101,37 @@ def get_last_visual_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest visual memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
|
||||
def get_last_memory_listen(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest audio memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
audio_memory = service.get_latest_audio_memory(group_id)
|
||||
audio_memory = service.get_latest_audio_memory(end_user_id)
|
||||
|
||||
if audio_memory is None:
|
||||
api_logger.info(f"No audio memory found: group_id={group_id}")
|
||||
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No audio memory available"
|
||||
@@ -145,38 +145,38 @@ def get_last_memory_listen(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest audio memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_text", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
|
||||
def get_last_text_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent TEXT-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest text memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
# 调用服务层获取最近的文本记忆
|
||||
service = MemoryPerceptualService(db)
|
||||
text_memory = service.get_latest_text_memory(group_id)
|
||||
text_memory = service.get_latest_text_memory(end_user_id)
|
||||
|
||||
if text_memory is None:
|
||||
api_logger.info(f"No text memory found: group_id={group_id}")
|
||||
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No text memory available"
|
||||
@@ -190,16 +190,16 @@ def get_last_text_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest text memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/timeline", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
|
||||
def get_memory_time_line(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||
@@ -209,7 +209,7 @@ def get_memory_time_line(
|
||||
"""Retrieve a timeline of perceptual memories for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
perceptual_type: Optional filter for perceptual type
|
||||
page: Page number for pagination
|
||||
page_size: Number of items per page
|
||||
@@ -221,7 +221,7 @@ def get_memory_time_line(
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||
f"group_id={group_id}, type={perceptual_type}, page={page}"
|
||||
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -232,7 +232,7 @@ def get_memory_time_line(
|
||||
)
|
||||
|
||||
service = MemoryPerceptualService(db)
|
||||
timeline_data = service.get_time_line(group_id, query)
|
||||
timeline_data = service.get_time_line(end_user_id, query)
|
||||
|
||||
api_logger.info(
|
||||
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||
@@ -246,7 +246,7 @@ def get_memory_time_line(
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
|
||||
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
return fail(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
@@ -11,7 +12,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.memory_reflection_service import (
|
||||
@@ -24,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -42,15 +45,15 @@ async def save_reflection_config(
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
if not config_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
data_config = DataConfigRepository.update_reflection_config(
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
enable_self_reflexion=request.reflection_enabled,
|
||||
@@ -63,17 +66,17 @@ async def save_reflection_config(
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(data_config)
|
||||
db.refresh(memory_config)
|
||||
|
||||
reflection_result={
|
||||
"config_id": data_config.config_id,
|
||||
"enable_self_reflexion": data_config.enable_self_reflexion,
|
||||
"iteration_period": data_config.iteration_period,
|
||||
"reflexion_range": data_config.reflexion_range,
|
||||
"baseline": data_config.baseline,
|
||||
"reflection_model_id": data_config.reflection_model_id,
|
||||
"memory_verify": data_config.memory_verify,
|
||||
"quality_assessment": data_config.quality_assessment}
|
||||
"config_id": memory_config.config_id,
|
||||
"enable_self_reflexion": memory_config.enable_self_reflexion,
|
||||
"iteration_period": memory_config.iteration_period,
|
||||
"reflexion_range": memory_config.reflexion_range,
|
||||
"baseline": memory_config.baseline,
|
||||
"reflection_model_id": memory_config.reflection_model_id,
|
||||
"memory_verify": memory_config.memory_verify,
|
||||
"quality_assessment": memory_config.quality_assessment}
|
||||
|
||||
return success(data=reflection_result, msg="反思配置成功")
|
||||
|
||||
@@ -98,7 +101,7 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
@@ -107,42 +110,55 @@ async def start_workspace_reflection(
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
# 跳过没有配置的应用
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
base_config = int(base['config']) if base['config'] else 0
|
||||
config_id = int(config['config_id']) if config['config_id'] else 0
|
||||
except (ValueError, TypeError):
|
||||
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
if base_config == config_id and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
# 为每个用户执行反思
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
try:
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": {
|
||||
"status": "错误",
|
||||
"message": f"反思失败: {str(e)}"
|
||||
}
|
||||
})
|
||||
|
||||
return success(data=reflection_results, msg="反思配置成功")
|
||||
|
||||
@@ -156,17 +172,20 @@ async def start_workspace_reflection(
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: int,
|
||||
config_id: uuid.UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询data_config表中的反思配置信息"""
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
"reflection_period_in_hours": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
@@ -191,7 +210,7 @@ async def start_reflection_configs(
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
config_id: UUID|int,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -199,9 +218,9 @@ async def reflection_run(
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -34,6 +35,8 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -140,7 +143,6 @@ def create_config(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
@@ -160,12 +162,12 @@ def create_config(
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
@@ -187,12 +189,17 @@ def update_config(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -210,7 +217,7 @@ def update_config_extracted(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
@@ -232,12 +239,12 @@ def update_config_extracted(
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
@@ -285,6 +292,7 @@ async def pilot_run(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
)
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
@@ -420,15 +428,95 @@ async def get_hot_memory_tags_api(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
缓存策略:
|
||||
- 缓存键:workspace_id + limit
|
||||
- 过期时间:5分钟(300秒)
|
||||
- 缓存命中:~50ms
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
import json
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
try:
|
||||
data = json.loads(cached_result)
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
cache_data = json.dumps(result, ensure_ascii=False)
|
||||
await aio_redis_set(cache_key, cache_data, expire=300)
|
||||
api_logger.info(f"Cached result for key: {cache_key}")
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
用于:
|
||||
- 手动刷新数据
|
||||
- 调试和测试
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
result = await aio_redis_delete(cache_key)
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -20,18 +20,18 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/{group_id}/conversations", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -39,7 +39,7 @@ def get_conversations(
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
|
||||
Args:
|
||||
group_id (UUID): The group identifier.
|
||||
end_user_id (UUID): The group identifier.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
@@ -53,7 +53,7 @@ def get_conversations(
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
group_id
|
||||
end_user_id
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
@@ -63,7 +63,7 @@ def get_conversations(
|
||||
], msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
def get_messages(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -100,7 +100,7 @@ def get_messages(
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/detail", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
|
||||
async def get_conversation_detail(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -3,15 +3,17 @@ from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||
from app.models.user_model import User
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -24,24 +26,83 @@ router = APIRouter(
|
||||
|
||||
@router.get("/type", response_model=ApiResponse)
|
||||
def get_model_types():
|
||||
|
||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||
|
||||
|
||||
@router.get("/provider", response_model=ApiResponse)
|
||||
def get_model_providers():
|
||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
||||
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||
return success(msg="获取模型提供商成功", data=providers)
|
||||
|
||||
@router.get("/strategy", response_model=ApiResponse)
|
||||
def get_model_strategies():
|
||||
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(
|
||||
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/new", response_model=ApiResponse)
|
||||
def get_model_list_new(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -53,36 +114,127 @@ def get_model_list(
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
api_logger.info(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQueryNew(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
is_composite=is_composite,
|
||||
search=search
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
|
||||
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/model_plaza", response_model=ApiResponse)
|
||||
def get_model_plaza_list(
|
||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""模型广场查询接口(按供应商分组)"""
|
||||
|
||||
query = model_schema.ModelBaseQuery(
|
||||
type=type,
|
||||
provider=provider,
|
||||
is_official=is_official,
|
||||
is_deprecated=is_deprecated,
|
||||
search=search
|
||||
)
|
||||
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
return success(data=result, msg="模型广场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def get_model_base_by_id(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取基础模型详情"""
|
||||
|
||||
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza", response_model=ApiResponse)
|
||||
def create_model_base(
|
||||
data: model_schema.ModelBaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建基础模型"""
|
||||
|
||||
result = ModelBaseService.create_model_base(db=db, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
|
||||
|
||||
|
||||
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def update_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
data: model_schema.ModelBaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新基础模型"""
|
||||
|
||||
# 不允许更改type类型
|
||||
if data.type is not None or data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||
|
||||
|
||||
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def delete_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除基础模型"""
|
||||
|
||||
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
|
||||
return success(msg="基础模型删除成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
||||
def add_model_from_plaza(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从模型广场添加模型到模型列表"""
|
||||
|
||||
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
|
||||
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
@@ -138,6 +290,73 @@ async def create_model(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建组合模型
|
||||
|
||||
- 绑定一个或多个现有的 API Key
|
||||
- 所有 API Key 必须来自非组合模型
|
||||
- 所有 API Key 关联的模型类型必须与组合模型类型一致
|
||||
"""
|
||||
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新组合模型"""
|
||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
if model_data.type is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/composite/{model_id}", response_model=ApiResponse)
|
||||
def delete_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除组合模型"""
|
||||
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型删除成功: model_id={model_id}")
|
||||
return success(msg="组合模型删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
@@ -214,6 +433,53 @@ def get_model_api_keys(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/provider/apikeys", response_model=ApiResponse)
|
||||
async def create_model_api_key_by_provider(
|
||||
api_key_data: model_schema.ModelApiKeyCreateByProvider,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据供应商为所有匹配的模型创建API Key
|
||||
"""
|
||||
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 根据tenant_id和provider筛选model_config_id列表
|
||||
model_config_ids = api_key_data.model_config_ids
|
||||
if not model_config_ids:
|
||||
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
|
||||
db=db,
|
||||
tenant_id=current_user.tenant_id,
|
||||
provider=api_key_data.provider
|
||||
)
|
||||
|
||||
if not model_config_ids:
|
||||
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 构造schema并调用service
|
||||
create_data = model_schema.ModelApiKeyCreateByProvider(
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
description=api_key_data.description,
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
||||
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
||||
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_model_api_key(
|
||||
model_id: uuid.UUID,
|
||||
@@ -228,11 +494,12 @@ async def create_model_api_key(
|
||||
|
||||
try:
|
||||
# 设置模型配置ID
|
||||
api_key_data.model_config_id = model_id
|
||||
api_key_data.model_config_ids = [model_id]
|
||||
|
||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
||||
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
|
||||
result = model_schema.ModelApiKey.model_validate(result_orm)
|
||||
return success(data=result, msg="模型API Key创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||
@@ -334,5 +601,3 @@ async def validate_model_config(
|
||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
1005
api/app/controllers/ontology_controller.py
Normal file
1005
api/app/controllers/ontology_controller.py
Normal file
File diff suppressed because it is too large
Load Diff
611
api/app/controllers/ontology_secondary_routes.py
Normal file
611
api/app/controllers/ontology_secondary_routes.py
Normal file
@@ -0,0 +1,611 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景和类型路由(续)
|
||||
|
||||
由于主Controller文件较大,将剩余路由放在此文件中。
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.ontology_schemas import (
|
||||
SceneResponse,
|
||||
SceneListResponse,
|
||||
PaginationInfo,
|
||||
ClassCreateRequest,
|
||||
ClassUpdateRequest,
|
||||
ClassResponse,
|
||||
ClassListResponse,
|
||||
ClassBatchCreateResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
"""获取OntologyService实例(不需要LLM)
|
||||
|
||||
场景和类型管理不需要LLM,创建一个dummy配置。
|
||||
"""
|
||||
dummy_config = RedBearModelConfig(
|
||||
model_name="dummy",
|
||||
provider="openai",
|
||||
api_key="dummy",
|
||||
base_url="https://api.openai.com/v1"
|
||||
)
|
||||
llm_client = OpenAIClient(model_config=dummy_config)
|
||||
return OntologyService(llm_client=llm_client, db=db)
|
||||
|
||||
|
||||
# 这些函数将被导入到主Controller中
|
||||
|
||||
async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
|
||||
|
||||
当提供 scene_name 参数时,进行模糊搜索(不分页);
|
||||
当不提供 scene_name 参数时,返回所有场景(支持分页)。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确定工作空间ID
|
||||
if workspace_id:
|
||||
try:
|
||||
ws_uuid = UUID(workspace_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
|
||||
else:
|
||||
ws_uuid = current_user.current_workspace_id
|
||||
if not ws_uuid:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 根据是否提供 scene_name 决定查询方式
|
||||
if scene_name and scene_name.strip():
|
||||
# 验证分页参数(模糊搜索也支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and page_size is not None:
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(
|
||||
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
|
||||
f"in workspace {ws_uuid}, total={total}"
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
# 验证分页参数
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
# ==================== 本体类型管理接口 ====================
|
||||
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
# 根据列表长度判断是单个还是批量
|
||||
count = len(request.classes)
|
||||
mode = "single" if count == 1 else "batch"
|
||||
|
||||
api_logger.info(
|
||||
f"Class creation ({mode}) requested by user {current_user.id}, "
|
||||
f"scene_id={request.scene_id}, count={count}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 准备类型数据
|
||||
classes_data = [
|
||||
{
|
||||
"class_name": item.class_name,
|
||||
"class_description": item.class_description
|
||||
}
|
||||
for item in request.classes
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建
|
||||
class_data = classes_data[0]
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
class_description=class_data["class_description"],
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建单个响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
|
||||
|
||||
else:
|
||||
# 批量创建
|
||||
created_classes, errors = service.create_classes_batch(
|
||||
scene_id=request.scene_id,
|
||||
classes=classes_data,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建批量响应
|
||||
items = []
|
||||
for ontology_class in created_classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassBatchCreateResponse(
|
||||
total=len(classes_data),
|
||||
success_count=len(created_classes),
|
||||
failed_count=len(errors),
|
||||
items=items,
|
||||
errors=errors if errors else None
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Batch class creation completed: "
|
||||
f"success={len(created_classes)}, failed={len(errors)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
|
||||
async def update_class_handler(
|
||||
class_id: str,
|
||||
request: ClassUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新本体类型"""
|
||||
api_logger.info(
|
||||
f"Class update requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 更新类型
|
||||
ontology_class = service.update_class(
|
||||
class_id=class_uuid,
|
||||
class_name=request.class_name,
|
||||
class_description=request.class_description,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class updated successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class update: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
|
||||
async def delete_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除本体类型"""
|
||||
api_logger.info(
|
||||
f"Class deletion requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 删除类型
|
||||
success_flag = service.delete_class(
|
||||
class_id=class_uuid,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
api_logger.info(f"Class deleted successfully: {class_id}")
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="类型删除成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
|
||||
async def get_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取单个本体类型"""
|
||||
api_logger.info(
|
||||
f"Get class requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取类型(会抛出ValueError如果不存在)
|
||||
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class retrieved successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 类型不存在或无权限访问
|
||||
api_logger.warning(f"Validation error in get class: {str(e)}")
|
||||
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
async def classes_handler(
|
||||
scene_id: str,
|
||||
class_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取类型列表(支持模糊搜索和全量查询)
|
||||
|
||||
当提供 class_name 参数时,进行模糊搜索;
|
||||
当不提供 class_name 参数时,返回场景下的所有类型。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID(必填)
|
||||
class_name: 类型名称关键词(可选,支持模糊匹配)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if class_name else "list"
|
||||
api_logger.info(
|
||||
f"Class {operation} requested by user {current_user.id}, "
|
||||
f"keyword={class_name}, scene_id={scene_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
scene_uuid = UUID(scene_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid scene_id format: {scene_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取场景信息
|
||||
scene = service.get_scene_by_id(scene_uuid, workspace_id)
|
||||
if not scene:
|
||||
api_logger.warning(f"Scene not found: {scene_id}")
|
||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
|
||||
|
||||
# 根据是否提供 class_name 决定查询方式
|
||||
if class_name and class_name.strip():
|
||||
# 模糊搜索类型
|
||||
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
|
||||
else:
|
||||
# 获取所有类型
|
||||
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for ontology_class in classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassListResponse(
|
||||
total=len(items),
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
items=items
|
||||
)
|
||||
|
||||
if class_name:
|
||||
api_logger.info(
|
||||
f"Class search completed: found {len(items)} classes matching '{class_name}' "
|
||||
f"in scene {scene_id}"
|
||||
)
|
||||
else:
|
||||
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.prompt_optimizer_schema import (
|
||||
PromptOptMessage,
|
||||
CreateSessionResponse,
|
||||
SessionHistoryResponse,
|
||||
SessionMessage,
|
||||
PromptSaveRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
@@ -135,3 +139,109 @@ async def get_prompt_opt(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/releases",
|
||||
summary="Get prompt optimization",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def save_prompt(
|
||||
data: PromptSaveRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Save a prompt release for the current tenant.
|
||||
|
||||
Args:
|
||||
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
|
||||
db (Session): SQLAlchemy database session, injected via dependency.
|
||||
current_user: Currently authenticated user object, injected via dependency.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Standard API response containing the saved prompt release info:
|
||||
- id: UUID of the prompt release
|
||||
- session_id: associated session
|
||||
- title: prompt title
|
||||
- prompt: prompt content
|
||||
- created_at: timestamp of creation
|
||||
|
||||
Raises:
|
||||
Any database or service exceptions are propagated to the global exception handler.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
prompt_info = service.save_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=data.session_id,
|
||||
title=data.title,
|
||||
prompt=data.prompt
|
||||
)
|
||||
return success(data=prompt_info)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/releases/{prompt_id}",
|
||||
summary="Delete prompt (soft delete)",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def delete_prompt(
|
||||
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Soft delete a prompt release.
|
||||
|
||||
Args:
|
||||
prompt_id
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Success message confirming deletion
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.delete_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
prompt_id=prompt_id
|
||||
)
|
||||
return success(msg="Prompt deleted successfully")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/releases/list",
|
||||
summary="Get paginated list of released prompts with optional filter",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def get_release_list(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
keyword: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieve paginated list of released prompts for the current tenant.
|
||||
Optionally filter by keyword in title.
|
||||
|
||||
Args:
|
||||
page (int): Page number (starting from 1)
|
||||
page_size (int): Number of items per page (max 100)
|
||||
keyword (str | None): Optional keyword to filter prompt titles
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains paginated list of prompt releases with metadata
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
result = service.get_release_list(
|
||||
tenant_id=current_user.tenant_id,
|
||||
page=max(1, page),
|
||||
page_size=min(max(1, page_size), 100),
|
||||
filter_keyword=keyword
|
||||
)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
|
||||
@@ -317,9 +317,12 @@ async def chat(
|
||||
appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(App.id == appid).first()
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
|
||||
@@ -235,11 +235,11 @@ async def chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
@@ -268,11 +268,11 @@ async def chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
@@ -135,27 +135,27 @@ async def generate_cache_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
group_id = request.end_user_id
|
||||
end_user_id = request.end_user_id
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
|
||||
)
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"insight_success": insight_result["success"],
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
@@ -175,9 +175,9 @@ async def generate_cache_api(
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ async def create_workflow_config(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -214,7 +214,7 @@ async def delete_workflow_config(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -259,7 +259,7 @@ async def validate_workflow_config(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -329,7 +329,7 @@ async def get_workflow_executions(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -389,7 +389,7 @@ async def get_workflow_execution(
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -440,7 +440,7 @@ async def run_workflow(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -578,7 +578,7 @@ async def cancel_workflow_execution(
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
|
||||
@@ -7,27 +7,21 @@ LangChain Agent 封装
|
||||
- 支持流式输出
|
||||
- 使用 RedBearLLM 支持多提供商
|
||||
"""
|
||||
import os
|
||||
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -104,7 +98,7 @@ class LangChainAgent:
|
||||
"streaming": streaming,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
# "tool_count": len(self.tools)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -143,100 +137,7 @@ class LangChainAgent:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
# end_user_end=f"Term_{end_user_end}"
|
||||
# print(messages)
|
||||
# print(aimessages)
|
||||
# session_id = store.save_session(
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# group_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
# messagss_list=[]
|
||||
# retrieved_content=[]
|
||||
# for messages in history:
|
||||
# query = messages.get("Query")
|
||||
# aimessages = messages.get("Answer")
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
if storage_type == "rag":
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # group_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -281,30 +182,6 @@ class LangChainAgent:
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# db_for_memory = next(get_db())
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# print(retrieved_content)
|
||||
# # 为长期记忆操作获取新的数据库连接
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
# # 长期记忆写入(
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -325,17 +202,17 @@ class LangChainAgent:
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
content = msg.content
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -343,7 +220,7 @@ class LangChainAgent:
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,25 +280,7 @@ class LangChainAgent:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
# # TODO 乐力齐
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# db_for_memory = next(get_db())
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# # 长期记忆写入
|
||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to long term memory: {e}")
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
@@ -437,7 +296,7 @@ class LangChainAgent:
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content=''
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
@@ -474,12 +333,17 @@ class LangChainAgent:
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -9,6 +9,25 @@ load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
# ========================================================================
|
||||
# Deployment Mode Configuration
|
||||
# ========================================================================
|
||||
# community: 社区版(开源,功能受限)
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
@@ -72,6 +91,10 @@ class Settings:
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
@@ -107,6 +130,7 @@ class Settings:
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
@@ -133,6 +157,11 @@ class Settings:
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
# Language Configuration
|
||||
# Supported values: "zh" (Chinese), "en" (English)
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
@@ -184,7 +213,7 @@ class Settings:
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
databasets = {}
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
@@ -52,9 +52,9 @@ async def rag_config(state):
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
kb_config = await rag_config(state)
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
group_id=state.get('group_id', '')
|
||||
end_user_id=state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取group_id
|
||||
# 从state中获取end_user_id
|
||||
import time
|
||||
start=time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(group_id)
|
||||
search_params = { "group_id": group_id, "return_raw_results": True }
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
group_id = state.get("group_id", '')
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
group_id = state.get("group_id", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
user_id=group_id,
|
||||
user_id=end_user_id,
|
||||
query=data,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
apply_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
group_id=state.get("group_id", '')
|
||||
end_user_id=state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
@@ -236,7 +236,7 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||
'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -276,7 +276,6 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
aimessages=await summary_llm(state,history,data,
|
||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -295,9 +294,26 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
async def Summary_fails(state: ReadState)-> ReadState:
|
||||
storage_type=state.get("storage_type", '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result= {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write_node(state: WriteState) -> WriteState:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, group_id, and memory_config
|
||||
state: WriteState containing messages, end_user_id, and memory_config
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
"""
|
||||
messages = state.get('messages', [])
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
for msg in messages:
|
||||
@@ -28,13 +29,11 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
"role": role,
|
||||
"content": msg.content # content is now guaranteed to be a string
|
||||
})
|
||||
|
||||
|
||||
try:
|
||||
result = await write(
|
||||
messages=structured_messages,
|
||||
user_id=group_id,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
@@ -79,7 +79,7 @@ async def make_read_graph():
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
group_id = '88a459f5_text09' # 组ID
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
@@ -95,9 +95,9 @@ async def main():
|
||||
start=time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
json_schema = WriteAggregateModel.model_json_schema()
|
||||
template_service = TemplateService(template_root)
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='write_aggregate_judgment.jinja2',
|
||||
operation_name='aggregate_judgment',
|
||||
history=history,
|
||||
sentence=ori_messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
factory = MemoryClientFactory(db_session)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": system_prompt
|
||||
}
|
||||
]
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=WriteAggregateModel
|
||||
)
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
"is_same_event": structured.is_same_event,
|
||||
"output": output_value
|
||||
}
|
||||
if not structured.is_same_event:
|
||||
logger.info(result_dict)
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
def create_time_retrieval_tool(group_id: str):
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
"""
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
|
||||
return data
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- group_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
actual_group_id = group_id_param or group_id
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
results = await search_by_temporal(
|
||||
group_id=actual_group_id,
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=10
|
||||
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
|
||||
# 关键词时间搜索
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=15
|
||||
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含group_id, limit, include等
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
"""
|
||||
|
||||
def clean_result_fields(data):
|
||||
@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
@@ -211,7 +212,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -224,7 +225,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
@@ -238,7 +239,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id or search_params.get("group_id"),
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
@@ -291,7 +292,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
@@ -301,7 +302,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
async def _async_search():
|
||||
@@ -311,7 +312,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"context": context,
|
||||
"search_type": search_type,
|
||||
"limit": limit,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
for history_messag in hstory_messages.strip().splitlines():
|
||||
history_messag = json.loads(history_messag)
|
||||
for content in history_messag:
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
return result
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
for data in Query:
|
||||
role = data['role']
|
||||
if role == "human":
|
||||
user.append(data['content'])
|
||||
if role == "ai":
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{user_content}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"{ai_content}"
|
||||
}
|
||||
|
||||
]
|
||||
return messages
|
||||
@@ -1,21 +1,20 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -26,8 +25,12 @@ async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
The workflow directly processes messages from the initial state
|
||||
and saves them to Neo4j storage.
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
@@ -38,43 +41,63 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
group_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j'==node_name:
|
||||
massages=node_data
|
||||
massages=massages.get('write_result')['status']
|
||||
print(massages) # | 更新数据: {node_data}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Pydantic models for write aggregate judgment operations."""
|
||||
|
||||
from typing import List, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
|
||||
"""Individual message item in conversation."""
|
||||
|
||||
role: str = Field(..., description="角色:user 或 assistant")
|
||||
content: str = Field(..., description="消息内容")
|
||||
|
||||
|
||||
class WriteAggregateResponse(BaseModel):
|
||||
"""Response model for aggregate judgment containing judgment result and output."""
|
||||
|
||||
is_same_event: bool = Field(
|
||||
...,
|
||||
description="是否是同一事件。True表示是同一事件,False表示不同事件"
|
||||
)
|
||||
output: Union[List[MessageItem], bool] = Field(
|
||||
...,
|
||||
description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
|
||||
)
|
||||
|
||||
|
||||
# 为了保持向后兼容,保留旧的类名作为别名
|
||||
WriteAggregateModel = WriteAggregateResponse
|
||||
@@ -24,7 +24,7 @@ class ParameterBuilder:
|
||||
tool_call_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
@@ -44,7 +44,7 @@ class ParameterBuilder:
|
||||
tool_call_id: Extracted tool call identifier
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
@@ -55,7 +55,7 @@ class ParameterBuilder:
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
"end_user_id": end_user_id
|
||||
}
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
|
||||
@@ -91,7 +91,7 @@ class SearchService:
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
@@ -105,7 +105,7 @@ class SearchService:
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering results
|
||||
end_user_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
@@ -130,7 +130,7 @@ class SearchService:
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
@@ -186,7 +186,7 @@ class SearchService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
@@ -32,42 +28,40 @@ async def get_chunked_dialogs(
|
||||
"""
|
||||
from app.core.logging_config import get_agent_logger
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
||||
raise ValueError("messages parameter must be a non-empty list")
|
||||
|
||||
|
||||
conversation_messages = []
|
||||
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
|
||||
conversation_context = ConversationContext(msgs=conversation_messages)
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
|
||||
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
||||
|
||||
return [dialog_data]
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
memory_config: object
|
||||
write_result: dict
|
||||
data:str
|
||||
data: str
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
@@ -28,7 +27,7 @@ class ReadState(TypedDict):
|
||||
messages: 消息列表,支持自动追加
|
||||
loop_count: 遍历次数
|
||||
search_switch: 搜索类型开关
|
||||
group_id: 组标识
|
||||
end_user_id: 组标识
|
||||
config_id: 配置ID,用于过滤结果
|
||||
data: 从content_input_node传递的内容数据
|
||||
spit_data: 从Split_The_Problem传递的分解结果
|
||||
@@ -39,7 +38,7 @@ class ReadState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
||||
loop_count: int
|
||||
search_switch: str
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# 角色
|
||||
你是一个智能问答助手,基于检索信息和历史对话回答用户问题。
|
||||
# 任务
|
||||
根据提供的上下文信息回答用户的问题。
|
||||
# 输入信息
|
||||
- 历史对话:{{history}}
|
||||
- 检索信息:{{retrieve_info}}
|
||||
# 用户问题
|
||||
{{query}}
|
||||
# 回答指南
|
||||
## 1. 仔细阅读检索信息
|
||||
- 答案可能直接或间接地出现在检索信息中
|
||||
- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼"
|
||||
- 第三人称描述的偏好、行为通常指用户本人
|
||||
|
||||
## 2. 判断信息相关性
|
||||
**情况A:信息匹配问题**
|
||||
- 直接回答,像自然对话一样
|
||||
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||
|
||||
**情况B:信息部分相关**
|
||||
- 先回答已知部分,再自然地询问更多信息
|
||||
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||
|
||||
**情况C:信息完全不相关**
|
||||
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||
- 使用友好的表达:
|
||||
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||
- 避免僵硬的"信息不足,无法回答"
|
||||
## 3. 回答要求
|
||||
- 像人类对话一样自然流畅
|
||||
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
|
||||
- 不要解释推理过程或引用信息来源
|
||||
- 保持友好、乐于助人的语气
|
||||
- 使用与问题相同的语言回答
|
||||
# 关键示例
|
||||
**示例1 - 直接匹配:**
|
||||
- 检索信息:"小曼会使用Python..."
|
||||
- 问题:"我叫什么"
|
||||
- ✓ 正确:"你叫小曼"
|
||||
- ✗ 错误:"你没有告诉我你的名字"
|
||||
**示例2 - 间接匹配:**
|
||||
- 检索信息:"用户很喜欢吃星巴克的甜品"
|
||||
- 问题:"我喜欢什么"
|
||||
- ✓ 正确:"你很喜欢吃星巴克的甜品"
|
||||
- ✗ 错误:"信息不足"
|
||||
**示例3 - 信息不匹配(推荐做法):**
|
||||
- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦"
|
||||
- 问题:"我吃过哪家面包"
|
||||
- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?"
|
||||
- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?"
|
||||
- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问)
|
||||
- ✗ 错误:"信息不足,无法回答。"(太僵硬)
|
||||
# 重要提醒
|
||||
- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字
|
||||
- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度
|
||||
- 用对话式语言表达"不知道",而非机械模板
|
||||
- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆
|
||||
@@ -0,0 +1,43 @@
|
||||
{# 角色定义 #}
|
||||
你是专业的问题解答专家+引导学者
|
||||
|
||||
{# 输入数据展示 #}
|
||||
{% if data %}
|
||||
## 输入数据
|
||||
上下文信息:
|
||||
{% for item in data.history %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
检索到的所有信息:
|
||||
{% for item in data.retrieve_info %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## User Query
|
||||
{{ query }}
|
||||
|
||||
{# 问题回答标准 #}
|
||||
## 问题回答核心标准
|
||||
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。
|
||||
注意,仔细阅读检索信息,答案可能直接或间接地出现在检索信息中或者历史上下文消息中,同时需要 判断信息相关性
|
||||
**情况A:信息匹配问题**
|
||||
- 直接回答,像自然对话一样
|
||||
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||
|
||||
**情况B:信息部分相关**
|
||||
- 先回答已知部分,再自然地询问更多信息
|
||||
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||
|
||||
**情况C:信息完全不相关**
|
||||
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||
- 使用友好的表达:
|
||||
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||
- 避免僵硬的"信息不足,无法回答"
|
||||
|
||||
{# 重要提醒 #}
|
||||
当检索以及上下文的历史信息都无法回答的时候,可引导对方进行提问/回答,或者进行其他引导
|
||||
当检索或者上下文中出现了,相似的问题,可以委婉,提醒对方,我记得刚刚提过这个问题,但是我自己不记得了,能在描述一次吗~以此为例
|
||||
@@ -0,0 +1,57 @@
|
||||
输入句子:{{sentence}}
|
||||
历史消息:{{history}}
|
||||
|
||||
# 你的角色
|
||||
你是一个擅长事件聚合与语义判断的专家。
|
||||
|
||||
# 你的任务
|
||||
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
||||
|
||||
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
||||
- 描述的是同一个具体事件或事实
|
||||
- 存在明显的因果关系、前后发展关系
|
||||
- 是对同一事件的补充、解释、追问或延展
|
||||
- 逻辑上属于同一语境下的连续讨论
|
||||
|
||||
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
||||
- 话题不同,事件主体不同
|
||||
- 时间、地点、对象明显不同
|
||||
- 只是语义相似,但并非同一具体事件
|
||||
- 无直接事件、因果或逻辑关联
|
||||
|
||||
# 输出规则(非常重要)
|
||||
你必须按照以下JSON格式输出:
|
||||
|
||||
**如果是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": true,
|
||||
"output": false
|
||||
}
|
||||
```
|
||||
|
||||
**如果不是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": false,
|
||||
"output": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "输入句子的内容"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "对应的回复内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
# JSON Schema
|
||||
{{json_schema}}
|
||||
|
||||
# 注意事项
|
||||
- 必须严格按照上述格式输出
|
||||
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
||||
- 消息列表必须包含 role 和 content 字段
|
||||
- 不要输出任何解释、分析或多余内容
|
||||
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from typing import Any, List, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def serialize_messages(messages: Any) -> str:
|
||||
"""
|
||||
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
|
||||
|
||||
Args:
|
||||
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
|
||||
|
||||
Returns:
|
||||
str: JSON 字符串
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
|
||||
if isinstance(messages, (list, tuple)):
|
||||
# 检查是否是 LangChain 消息对象列表
|
||||
serialized_list = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# LangChain 消息对象
|
||||
serialized_list.append({
|
||||
'type': msg.type,
|
||||
'content': msg.content,
|
||||
'role': getattr(msg, 'role', msg.type)
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
serialized_list.append(msg)
|
||||
else:
|
||||
serialized_list.append(str(msg))
|
||||
return json.dumps(serialized_list, ensure_ascii=False)
|
||||
|
||||
if isinstance(messages, dict):
|
||||
return json.dumps(messages, ensure_ascii=False)
|
||||
|
||||
# 其他类型转为字符串
|
||||
return str(messages)
|
||||
|
||||
|
||||
def deserialize_messages(messages_str: str) -> Any:
|
||||
"""
|
||||
将 JSON 字符串反序列化为原始格式
|
||||
|
||||
Args:
|
||||
messages_str: JSON 字符串
|
||||
|
||||
Returns:
|
||||
反序列化后的对象(list、dict 或 string)
|
||||
"""
|
||||
if not messages_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
return json.loads(messages_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return messages_str
|
||||
|
||||
|
||||
def fix_encoding(text: str) -> str:
|
||||
"""
|
||||
修复错误编码的文本
|
||||
|
||||
Args:
|
||||
text: 需要修复的文本
|
||||
|
||||
Returns:
|
||||
str: 修复后的文本
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
|
||||
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化会话数据为统一的输出格式
|
||||
|
||||
Args:
|
||||
data: 原始会话数据
|
||||
include_time: 是否包含时间字段
|
||||
|
||||
Returns:
|
||||
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
|
||||
"""
|
||||
result = {
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": fix_encoding(data.get('aimessages', ''))
|
||||
}
|
||||
|
||||
if include_time:
|
||||
result["starttime"] = data.get('starttime', '')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
|
||||
"""
|
||||
根据时间范围过滤数据
|
||||
|
||||
Args:
|
||||
items: 包含 starttime 字段的数据列表
|
||||
minutes: 时间范围(分钟)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 过滤后的数据列表
|
||||
"""
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
starttime = item.get('starttime', '')
|
||||
if starttime and starttime >= time_threshold_str:
|
||||
filtered_items.append(item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
|
||||
def sort_and_limit_results(items: List[Dict], limit: int = 6,
|
||||
remove_time: bool = True) -> List[Dict]:
|
||||
"""
|
||||
对结果进行排序、限制数量并移除时间字段
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
limit: 最大返回数量
|
||||
remove_time: 是否移除 starttime 字段
|
||||
|
||||
Returns:
|
||||
List[Dict]: 处理后的数据列表
|
||||
"""
|
||||
# 按时间降序排序(最新的在前)
|
||||
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
# 限制数量
|
||||
result_items = items[:limit]
|
||||
|
||||
# 移除 starttime 字段
|
||||
if remove_time:
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于1条,返回空列表
|
||||
if len(result_items) < 1:
|
||||
return []
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
def generate_session_key(session_id: str, key_type: str = "session") -> str:
|
||||
"""
|
||||
生成 Redis key
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
key_type: key 类型 ("session", "read", "write", "count")
|
||||
|
||||
Returns:
|
||||
str: Redis key
|
||||
"""
|
||||
if key_type == "count":
|
||||
return f"session:count:{session_id}"
|
||||
elif key_type == "write":
|
||||
return f"session:write:{session_id}"
|
||||
elif key_type == "session" or key_type == "read":
|
||||
return f"session:{session_id}"
|
||||
else:
|
||||
return f"session:{session_id}"
|
||||
|
||||
|
||||
def get_current_timestamp() -> str:
|
||||
"""
|
||||
获取当前时间戳字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,11 +1,36 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
filter_by_time_range,
|
||||
sort_and_limit_results,
|
||||
generate_session_key,
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -16,210 +41,633 @@ class RedisSessionStore:
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def _fix_encoding(self, text):
|
||||
"""修复错误编码的文本"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||
def save_session_write(self, userid: str, messages: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
messages = serialize_messages(messages)
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="write")
|
||||
|
||||
# 使用 pipeline 批量写入,减少网络往返
|
||||
pipe = self.r.pipeline()
|
||||
|
||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
|
||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||
|
||||
# 执行批量操作
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
通过 save_session_write 的 userid 获取 sessionid 和 messages
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
|
||||
"""
|
||||
try:
|
||||
session_ids = []
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
for session in sessions_data:
|
||||
session_id = str(uuid.uuid4())
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}"
|
||||
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"group_id": session.get('group_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
})
|
||||
|
||||
session_ids.append(session_id)
|
||||
|
||||
# 一次性执行所有写入操作
|
||||
results = pipe.execute()
|
||||
print(f"批量保存完成: {len(session_ids)} 条记录")
|
||||
return session_ids
|
||||
# 筛选符合 userid 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
results.append({
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"批量保存会话失败: {e}")
|
||||
raise e
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{
|
||||
"session_id": "uuid",
|
||||
"id": "...",
|
||||
"sessionid": "end_user_id",
|
||||
"messages": "...",
|
||||
"starttime": "timestamp"
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
"""
|
||||
读取一条会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
# 筛选符合 end_user_id 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"id": data.get('id', ''),
|
||||
"sessionid": data.get('sessionid', ''),
|
||||
"messages": fix_encoding(data.get('messages', '')),
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key in self.r.keys('session:*'):
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取所有会话数据
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
"""
|
||||
更新单个字段
|
||||
优化版本:使用 pipeline 减少网络往返
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0]) # 返回 key 是否存在
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
"""
|
||||
删除单条会话
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
"""
|
||||
删除所有会话
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
minutes: 查询最近几分钟的数据,默认5分钟
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 第一步:使用 pipeline 批量获取所有 key
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 第二步:使用 pipeline 批量获取所有数据
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 第三步:在内存中识别重复数据
|
||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||
keys_to_delete = [] # 需要删除的 key 列表
|
||||
# 筛选符合 userid 的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
matched_items.append({
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
|
||||
for key, data in zip(keys, all_data, strict=False):
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
def delete_all_write_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 write 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:write:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
保存用户访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
count: 访问次数
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
new_count: 新的 count 值
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回 True,未找到记录返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 count 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:count:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
aimessages: AI回复消息
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="read")
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
Dict 或 None: 会话数据
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取所有会话数据(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
Dict: 所有会话数据,key 为 session_id
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
# 排除 count 和 write 类型的 key
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
Args:
|
||||
sessionid: 会话ID(支持模糊匹配)
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
group_id = data.get('group_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
field: 字段名
|
||||
value: 字段值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self) -> int:
|
||||
"""
|
||||
删除所有会话(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
# 过滤掉 count 和 write 类型
|
||||
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
||||
if keys_to_delete:
|
||||
return self.r.delete(*keys_to_delete)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self) -> int:
|
||||
"""
|
||||
删除重复会话数据(不包括 count 和 write 类型)
|
||||
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 识别重复数据
|
||||
seen = {}
|
||||
keys_to_delete = []
|
||||
|
||||
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||
identifier = (
|
||||
data.get('sessionid', ''),
|
||||
data.get('id', ''),
|
||||
data.get('end_user_id', ''),
|
||||
data.get('messages', ''),
|
||||
data.get('aimessages', '')
|
||||
)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
keys_to_delete.append(key)
|
||||
else:
|
||||
# 第一次出现,记录
|
||||
seen[identifier] = key
|
||||
|
||||
# 第四步:使用 pipeline 批量删除重复的 key
|
||||
# 批量删除重复的 key
|
||||
deleted_count = 0
|
||||
if keys_to_delete:
|
||||
# 分批删除,避免单次操作过大
|
||||
batch_size = 1000
|
||||
for i in range(0, len(keys_to_delete), batch_size):
|
||||
batch = keys_to_delete[i:i + batch_size]
|
||||
@@ -233,79 +681,28 @@ class RedisSessionStore:
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self, sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
# 使用 pipeline 批量获取数据,提高性能
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 使用 pipeline 批量获取所有 hash 数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 解析并筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
"Query": self._fix_encoding(data.get('messages')),
|
||||
"Answer": self._fix_encoding(data.get('aimessages')),
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
# 按时间降序排序(最新的在前)
|
||||
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
# 只保留最新的6条
|
||||
result_items = matched_items[:6]
|
||||
# # 移除 starttime 字段
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
# 全局实例
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
)
|
||||
|
||||
write_store = RedisWriteStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
count_store = RedisCountStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
@@ -29,20 +30,18 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
@@ -51,14 +50,14 @@ async def write(
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
config_id = str(memory_config.config_id)
|
||||
|
||||
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
logger.info(f"end_user_id ID: {end_user_id}")
|
||||
|
||||
# Construct clients from memory_config using factory pattern with db session
|
||||
with get_db_context() as db:
|
||||
@@ -83,9 +82,7 @@ async def write(
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
@@ -127,23 +124,48 @@ async def write(
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 检查是否是死锁错误
|
||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
else:
|
||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||
raise
|
||||
else:
|
||||
# 非死锁错误,直接抛出
|
||||
raise
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing Neo4j connector: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
@@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def get_default_docs_path() -> str:
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
from pathlib import Path
|
||||
project_root = str(Path(__file__).resolve().parents[2])
|
||||
return os.path.join(project_root, "src", "analytics", "API接口.md")
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
group_id: 用户组ID,用于获取配置
|
||||
end_user_id: 用户组ID,用于获取配置
|
||||
|
||||
Returns:
|
||||
筛选后的标签列表
|
||||
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if not config_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for group_id: {group_id}. "
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||
"Please ensure the user has a valid memory configuration."
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
by_user: bool = False
|
||||
) -> List[Tuple[str, int]]:
|
||||
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
|
||||
else:
|
||||
query = (
|
||||
"MATCH (e:ExtractedEntity) "
|
||||
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"RETURN e.name AS name, count(e) AS frequency "
|
||||
"ORDER BY frequency DESC "
|
||||
"LIMIT $limit"
|
||||
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
|
||||
# 使用项目的Neo4jConnector执行查询
|
||||
results = await connector.execute_query(
|
||||
query,
|
||||
id=group_id,
|
||||
id=end_user_id,
|
||||
limit=limit,
|
||||
names_to_exclude=names_to_exclude
|
||||
)
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
|
||||
Args:
|
||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果group_id未提供或为空
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
# 验证group_id必须提供且不为空
|
||||
if not group_id or not group_id.strip():
|
||||
# 验证end_user_id必须提供且不为空
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"group_id is required. Please provide a valid group_id or user_id."
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
|
||||
@@ -75,8 +75,8 @@ class MemoryDataSource:
|
||||
start_date = time_range.start_date if time_range else None
|
||||
end_date = time_range.end_date if time_range else None
|
||||
|
||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
||||
group_id=user_id,
|
||||
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
|
||||
end_user_id=user_id,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
|
||||
@@ -2,13 +2,16 @@ import os
|
||||
import re
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT
|
||||
except Exception:
|
||||
# Fallback: derive project root from this file location
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py
|
||||
# 需要向上 5 级到达 api/ 目录
|
||||
PROJECT_ROOT = str(Path(__file__).resolve().parents[4])
|
||||
|
||||
|
||||
def _get_latest_prompt_log_path() -> str | None:
|
||||
@@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict:
|
||||
triplet_relations_count = 0
|
||||
temporal_count = 0
|
||||
|
||||
# Patterns
|
||||
# 正则表达式模式 - 匹配当前日志格式
|
||||
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
|
||||
pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)")
|
||||
pat_triplet_done = re.compile(
|
||||
r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)"
|
||||
pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=")
|
||||
pat_triplet_completed = re.compile(
|
||||
r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)"
|
||||
)
|
||||
pat_temporal_done = re.compile(
|
||||
r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)"
|
||||
pat_temporal_completed = re.compile(
|
||||
r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)"
|
||||
)
|
||||
|
||||
with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
for line in f:
|
||||
# Chunk prompts count (each chunk triggers one statement-extraction prompt render)
|
||||
# 文本块数量(每个块触发一次陈述提取提示)
|
||||
if pat_chunk_render.search(line):
|
||||
chunk_count += 1
|
||||
continue
|
||||
|
||||
m1 = pat_triplet_start.search(line)
|
||||
if m1:
|
||||
# 陈述数量(每个 Triplet Started 代表一个陈述被处理)
|
||||
if pat_triplet_started.search(line):
|
||||
statements_count += 1
|
||||
continue
|
||||
|
||||
# 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y
|
||||
m_triplet = pat_triplet_completed.search(line)
|
||||
if m_triplet:
|
||||
try:
|
||||
statements_count += int(m1.group(1))
|
||||
triplet_relations_count += int(m_triplet.group(1))
|
||||
triplet_entities_count += int(m_triplet.group(2))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
m2 = pat_triplet_done.search(line)
|
||||
if m2:
|
||||
# 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X
|
||||
m_temporal = pat_temporal_completed.search(line)
|
||||
if m_temporal:
|
||||
try:
|
||||
triplet_relations_count += int(m2.group(1))
|
||||
triplet_entities_count += int(m2.group(2))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
m3 = pat_temporal_done.search(line)
|
||||
if m3:
|
||||
try:
|
||||
temporal_count += int(m3.group(1))
|
||||
temporal_count += int(m_temporal.group(1))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
@@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict:
|
||||
|
||||
|
||||
def get_recent_activity_stats() -> Tuple[dict, str]:
|
||||
"""Get aggregated stats from all prompt logs in logs/.
|
||||
"""Get stats from the latest prompt log file only.
|
||||
|
||||
Returns (stats_dict, message).
|
||||
"""
|
||||
all_logs = _get_all_prompt_logs()
|
||||
# Fallback to recursive search if none found in logs/
|
||||
if not all_logs:
|
||||
# 获取最新的日志文件
|
||||
latest_log = _get_latest_prompt_log_path()
|
||||
|
||||
# 如果没有找到,尝试递归搜索
|
||||
if not latest_log:
|
||||
all_logs = _get_any_logs_recursive()
|
||||
if not all_logs:
|
||||
if all_logs:
|
||||
latest_log = all_logs[-1] # 取最新的
|
||||
|
||||
if not latest_log:
|
||||
return (
|
||||
{
|
||||
"chunk_count": 0,
|
||||
@@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]:
|
||||
"未找到日志文件,请确认已运行过提取流程。",
|
||||
)
|
||||
|
||||
agg = {
|
||||
"chunk_count": 0,
|
||||
"statements_count": 0,
|
||||
"triplet_entities_count": 0,
|
||||
"triplet_relations_count": 0,
|
||||
"temporal_count": 0,
|
||||
}
|
||||
for path in all_logs:
|
||||
s = parse_stats_from_log(path)
|
||||
agg["chunk_count"] += s.get("chunk_count", 0)
|
||||
agg["statements_count"] += s.get("statements_count", 0)
|
||||
agg["triplet_entities_count"] += s.get("triplet_entities_count", 0)
|
||||
agg["triplet_relations_count"] += s.get("triplet_relations_count", 0)
|
||||
agg["temporal_count"] += s.get("temporal_count", 0)
|
||||
|
||||
# Attach a summary of files combined
|
||||
agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}"
|
||||
return agg, "成功汇总 logs 目录中所有提示日志。"
|
||||
# 只解析最新的日志文件
|
||||
stats = parse_stats_from_log(latest_log)
|
||||
|
||||
# 添加日志文件路径信息
|
||||
stats["log_path"] = f"最新:{latest_log}"
|
||||
|
||||
return stats, "成功读取最近一次记忆活动统计。"
|
||||
|
||||
|
||||
def _format_summary(stats: dict) -> str:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Evaluation package with dataset-specific pipelines and a unified runner."""
|
||||
@@ -1,30 +0,0 @@
|
||||
⏬数据集下载地址:
|
||||
Locomo10.json:https://github.com/snap-research/locomo/tree/main/data
|
||||
LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
||||
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
||||
上方数据集下载好后全部放入app/core/memory/data文件夹中
|
||||
|
||||
全流程基准测试运行:
|
||||
locomo:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
|
||||
LongMemEval:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
|
||||
memsciqa:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
|
||||
|
||||
单独检索评估运行命令:
|
||||
python -m app.core.memory.evaluation.locomo.locomo_test
|
||||
python -m app.core.memory.evaluation.longmemeval.test_eval
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
|
||||
需要先在项目中修改需要检测评估的group_id。
|
||||
|
||||
参数及解释:
|
||||
● --dataset longmemeval - 指定数据集
|
||||
● --sample-size 10 - 评估10个样本
|
||||
● --start-index 0 - 从第0个样本开始
|
||||
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
|
||||
● --search-limit 8 - 检索限制8条
|
||||
● --context-char-budget 4000 - 上下文字符预算4000
|
||||
● --search-type hybrid - 使用混合检索
|
||||
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
|
||||
● --reset-group - 运行前清空组数据
|
||||
@@ -1,100 +0,0 @@
|
||||
import math
|
||||
import re
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def _normalize(text: str) -> List[str]:
|
||||
"""Lowercase, strip punctuation, and split into tokens."""
|
||||
text = text.lower().strip()
|
||||
# Python's re doesn't support \p classes; use a simple non-word filter
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
tokens = [t for t in text.split() if t]
|
||||
return tokens
|
||||
|
||||
|
||||
def exact_match(pred: str, ref: str) -> float:
|
||||
return float(_normalize(pred) == _normalize(ref))
|
||||
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
p = set(_normalize(pred))
|
||||
r = set(_normalize(ref))
|
||||
if not p and not r:
|
||||
return 1.0
|
||||
if not p or not r:
|
||||
return 0.0
|
||||
return len(p & r) / len(p | r)
|
||||
|
||||
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
p_tokens = _normalize(pred)
|
||||
r_tokens = _normalize(ref)
|
||||
if not p_tokens and not r_tokens:
|
||||
return 1.0
|
||||
if not p_tokens or not r_tokens:
|
||||
return 0.0
|
||||
p_set = set(p_tokens)
|
||||
r_set = set(r_tokens)
|
||||
tp = len(p_set & r_set)
|
||||
precision = tp / len(p_set) if p_set else 0.0
|
||||
recall = tp / len(r_set) if r_set else 0.0
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
|
||||
p_tokens = _normalize(pred)
|
||||
r_tokens = _normalize(ref)
|
||||
if not p_tokens:
|
||||
return 0.0
|
||||
# Clipped count
|
||||
r_counts: Dict[str, int] = {}
|
||||
for t in r_tokens:
|
||||
r_counts[t] = r_counts.get(t, 0) + 1
|
||||
clipped = 0
|
||||
p_counts: Dict[str, int] = {}
|
||||
for t in p_tokens:
|
||||
p_counts[t] = p_counts.get(t, 0) + 1
|
||||
for t, c in p_counts.items():
|
||||
clipped += min(c, r_counts.get(t, 0))
|
||||
precision = clipped / max(len(p_tokens), 1)
|
||||
# Brevity penalty
|
||||
ref_len = len(r_tokens)
|
||||
pred_len = len(p_tokens)
|
||||
if pred_len > ref_len or pred_len == 0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
||||
return bp * precision
|
||||
|
||||
|
||||
def percentile(values: List[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
vals = sorted(values)
|
||||
k = (len(vals) - 1) * p
|
||||
f = math.floor(k)
|
||||
c = math.ceil(k)
|
||||
if f == c:
|
||||
return vals[int(k)]
|
||||
return vals[f] + (k - f) * (vals[c] - vals[f])
|
||||
|
||||
|
||||
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
|
||||
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
|
||||
if not latencies_ms:
|
||||
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
|
||||
p25 = percentile(latencies_ms, 0.25)
|
||||
p50 = percentile(latencies_ms, 0.50)
|
||||
p75 = percentile(latencies_ms, 0.75)
|
||||
p95 = percentile(latencies_ms, 0.95)
|
||||
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
|
||||
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
|
||||
|
||||
|
||||
def avg_context_tokens(contexts: List[str]) -> float:
|
||||
if not contexts:
|
||||
return 0.0
|
||||
return sum(len(_normalize(c)) for c in contexts) / len(contexts)
|
||||
@@ -1,60 +0,0 @@
|
||||
"""
|
||||
Dialogue search queries for evaluation purposes.
|
||||
This file contains Cypher queries for searching dialogues, entities, and chunks.
|
||||
Placed in evaluation directory to avoid circular imports with src modules.
|
||||
"""
|
||||
|
||||
# Entity search queries
|
||||
SEARCH_ENTITIES_BY_NAME = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.name = $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.name CONTAINS $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
# Chunk search queries
|
||||
SEARCH_CHUNKS_BY_CONTENT = """
|
||||
MATCH (c:Chunk)
|
||||
WHERE c.content CONTAINS $content
|
||||
RETURN c
|
||||
"""
|
||||
|
||||
# Dialogue search queries
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.dialog_id = $dialog_id
|
||||
RETURN d
|
||||
"""
|
||||
|
||||
SEARCH_DIALOGUES_BY_CONTENT = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.content CONTAINS $q
|
||||
RETURN d
|
||||
"""
|
||||
|
||||
DIALOGUE_EMBEDDING_SEARCH = """
|
||||
WITH $embedding AS q
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.dialog_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR d.group_id = $group_id)
|
||||
WITH d, q, d.dialog_embedding AS v
|
||||
WITH d,
|
||||
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
|
||||
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
|
||||
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
|
||||
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
|
||||
WHERE score > $threshold
|
||||
RETURN d.id AS dialog_id,
|
||||
d.group_id AS group_id,
|
||||
d.content AS content,
|
||||
d.created_at AS created_at,
|
||||
d.expired_at AS expired_at,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
@@ -1,341 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_CHUNKER_STRATEGY,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
|
||||
|
||||
async def ingest_contexts_via_full_pipeline(
|
||||
contexts: List[str],
|
||||
group_id: str,
|
||||
chunker_strategy: str | None = None,
|
||||
embedding_name: str | None = None,
|
||||
save_chunk_output: bool = False,
|
||||
save_chunk_output_path: str | None = None,
|
||||
) -> bool:
|
||||
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
|
||||
|
||||
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
|
||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
||||
Args:
|
||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||
group_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
|
||||
Returns:
|
||||
True if data saved successfully, False otherwise.
|
||||
"""
|
||||
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
|
||||
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
|
||||
|
||||
# Initialize llm client with graceful fallback
|
||||
llm_client = None
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
|
||||
# Step A: Build DialogData list from contexts with robust parsing
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
dialog_data_list: List[DialogData] = []
|
||||
|
||||
for idx, ctx in enumerate(contexts):
|
||||
messages: List[ConversationMessage] = []
|
||||
|
||||
# Improved parsing: capture multi-line message blocks, normalize roles
|
||||
pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)"
|
||||
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
|
||||
|
||||
if matches:
|
||||
for m in matches:
|
||||
raw_role = m.group(1).strip()
|
||||
content = m.group(2).strip()
|
||||
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=content))
|
||||
else:
|
||||
# Fallback: line-by-line parsing
|
||||
for raw in ctx.split("\n"):
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line)
|
||||
if m:
|
||||
role = m.group(1).strip()
|
||||
msg = m.group(2).strip()
|
||||
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=msg))
|
||||
else:
|
||||
# Final fallback: treat as user message
|
||||
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
|
||||
messages.append(ConversationMessage(role=default_role, msg=line))
|
||||
|
||||
context_model = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context_model,
|
||||
ref_id=f"pipeline_item_{idx}",
|
||||
group_id=group_id,
|
||||
user_id="default_user",
|
||||
apply_id="default_application",
|
||||
)
|
||||
# Generate chunks
|
||||
dialog.chunks = await chunker.process_dialogue(dialog)
|
||||
dialog_data_list.append(dialog)
|
||||
|
||||
if not dialog_data_list:
|
||||
print("No dialogs to process for ingestion.")
|
||||
return False
|
||||
|
||||
# Optionally save chunking outputs for debugging
|
||||
if save_chunk_output:
|
||||
try:
|
||||
def _serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
default_path = settings.get_memory_output_path("chunker_test_output.txt")
|
||||
out_path = save_chunk_output_path or default_path
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
|
||||
print(f"Saved chunking results to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to save chunking results: {e}")
|
||||
|
||||
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
|
||||
if not llm_available:
|
||||
print("[Ingestion] Skipping extraction pipeline (no LLM).")
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to initialize embedder client: {e}")
|
||||
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
|
||||
return False
|
||||
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化并运行 ExtractionOrchestrator
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=connector,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 创建一个包装的 orchestrator 来修复时间提取器的输出
|
||||
# 保存原始的 _assign_extracted_data 方法
|
||||
original_assign = orchestrator._assign_extracted_data
|
||||
|
||||
def clean_temporal_value(value):
|
||||
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
# 处理字符串形式的 'null', 'None', 空字符串等
|
||||
if value.lower() in ('null', 'none', '') or value.strip() == '':
|
||||
return None
|
||||
return value
|
||||
|
||||
async def patched_assign_extracted_data(*args, **kwargs):
|
||||
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
|
||||
result = await original_assign(*args, **kwargs)
|
||||
|
||||
# 清理返回的 dialog_data_list 中的 temporal_validity
|
||||
for dialog in result:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
# 清理 valid_at 和 invalid_at
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
return result
|
||||
|
||||
# 替换方法
|
||||
orchestrator._assign_extracted_data = patched_assign_extracted_data
|
||||
|
||||
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
|
||||
original_create = orchestrator._create_nodes_and_edges
|
||||
|
||||
async def patched_create_nodes_and_edges(dialog_data_list_arg):
|
||||
"""包装方法:在创建节点前再次清理 temporal_validity"""
|
||||
# 最后一次清理,确保万无一失
|
||||
for dialog in dialog_data_list_arg:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
|
||||
return await original_create(dialog_data_list_arg)
|
||||
|
||||
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run 返回 7 个元素的元组
|
||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
) = result
|
||||
|
||||
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
|
||||
|
||||
# Step G: 生成记忆摘要
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
|
||||
summaries = []
|
||||
|
||||
# Step H: Save to Neo4j
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
entity_edges=entity_entity_edges,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
connector=connector
|
||||
)
|
||||
|
||||
# Save memory summaries separately
|
||||
if summaries:
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, connector)
|
||||
await add_memory_summary_statement_edges(summaries, connector)
|
||||
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to save summary nodes: {e}")
|
||||
|
||||
await connector.close()
|
||||
if success:
|
||||
print("Successfully saved extracted data to Neo4j!")
|
||||
else:
|
||||
print("Failed to save data to Neo4j")
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"Failed to save data to Neo4j: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def handle_context_processing(args):
|
||||
"""Handle context-based processing from command line arguments."""
|
||||
contexts = []
|
||||
|
||||
if args.contexts:
|
||||
contexts.extend(args.contexts)
|
||||
|
||||
if args.context_file:
|
||||
try:
|
||||
with open(args.context_file, 'r', encoding='utf-8') as f:
|
||||
contexts.extend(line.strip() for line in f if line.strip())
|
||||
except Exception as e:
|
||||
print(f"Error reading context file: {e}")
|
||||
return False
|
||||
|
||||
if not contexts:
|
||||
print("No contexts provided for processing.")
|
||||
return False
|
||||
|
||||
return await main_from_contexts(contexts, args.context_group_id)
|
||||
|
||||
|
||||
async def main_from_contexts(contexts: List[str], group_id: str):
|
||||
"""Run the pipeline from provided dialogue contexts instead of test data."""
|
||||
print("=== Running pipeline from provided contexts ===")
|
||||
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=contexts,
|
||||
group_id=group_id,
|
||||
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
|
||||
embedding_name=SELECTED_EMBEDDING_ID,
|
||||
save_chunk_output=True
|
||||
)
|
||||
|
||||
if success:
|
||||
print("Successfully processed and saved contexts to Neo4j!")
|
||||
else:
|
||||
print("Failed to process contexts.")
|
||||
|
||||
return success
|
||||
@@ -1,575 +0,0 @@
|
||||
"""
|
||||
LoCoMo Benchmark Script
|
||||
|
||||
This module provides the main entry point for running LoCoMo benchmark evaluations.
|
||||
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
|
||||
in a clean, maintainable way.
|
||||
|
||||
Usage:
|
||||
python locomo_benchmark.py --sample_size 20 --search_type hybrid
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
def load_dotenv():
|
||||
pass
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
get_category_name,
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
extract_conversations,
|
||||
ingest_conversations_if_needed,
|
||||
load_locomo_data,
|
||||
resolve_temporal_references,
|
||||
retrieve_relevant_information,
|
||||
select_and_format_information,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
async def run_locomo_benchmark(
|
||||
sample_size: int = 20,
|
||||
group_id: Optional[str] = None,
|
||||
search_type: str = "hybrid",
|
||||
search_limit: int = 12,
|
||||
context_char_budget: int = 8000,
|
||||
reset_group: bool = False,
|
||||
skip_ingest: bool = False,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run LoCoMo benchmark evaluation.
|
||||
|
||||
This function orchestrates the complete evaluation pipeline:
|
||||
1. Load LoCoMo dataset (only QA pairs from first conversation)
|
||||
2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True)
|
||||
3. For each question:
|
||||
- Retrieve relevant information
|
||||
- Generate answer using LLM
|
||||
- Calculate metrics
|
||||
4. Aggregate results and save to file
|
||||
|
||||
Note: By default, only the first conversation is ingested into the database,
|
||||
and only QA pairs from that conversation are evaluated. This ensures that
|
||||
all questions have corresponding memory in the database for retrieval.
|
||||
|
||||
Args:
|
||||
sample_size: Number of QA pairs to evaluate (from first conversation)
|
||||
group_id: Database group ID for retrieval (uses default if None)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max documents to retrieve per query
|
||||
context_char_budget: Max characters for context
|
||||
reset_group: Whether to clear and re-ingest data (not implemented)
|
||||
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
|
||||
output_dir: Directory to save results (uses default if None)
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation results including metrics, timing, and samples
|
||||
"""
|
||||
# Use default group_id if not provided
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
|
||||
# Determine data path
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
# Fallback to current directory
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("🚀 Starting LoCoMo Benchmark Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print("📊 Configuration:")
|
||||
print(f" Sample size: {sample_size}")
|
||||
print(f" Group ID: {group_id}")
|
||||
print(f" Search type: {search_type}")
|
||||
print(f" Search limit: {search_limit}")
|
||||
print(f" Context budget: {context_char_budget} chars")
|
||||
print(f" Data path: {data_path}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Step 1: Load LoCoMo data
|
||||
print("📂 Loading LoCoMo dataset...")
|
||||
try:
|
||||
# Only load QA pairs from the first conversation (index 0)
|
||||
# since we only ingest the first conversation into the database
|
||||
qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
|
||||
print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load data: {e}")
|
||||
return {
|
||||
"error": f"Data loading failed: {e}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Step 2: Extract conversations and ingest if needed
|
||||
if skip_ingest:
|
||||
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
|
||||
print(f" Group ID: {group_id}\n")
|
||||
else:
|
||||
print("💾 Checking database ingestion...")
|
||||
try:
|
||||
conversations = extract_conversations(data_path, max_dialogues=1)
|
||||
print(f"📝 Extracted {len(conversations)} conversations")
|
||||
|
||||
# Always ingest for now (ingestion check not implemented)
|
||||
print(f"🔄 Ingesting conversations into group '{group_id}'...")
|
||||
success = await ingest_conversations_if_needed(
|
||||
conversations=conversations,
|
||||
group_id=group_id,
|
||||
reset=reset_group
|
||||
)
|
||||
|
||||
if success:
|
||||
print("✅ Ingestion completed successfully\n")
|
||||
else:
|
||||
print("⚠️ Ingestion may have failed, continuing anyway\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ingestion failed: {e}")
|
||||
print("⚠️ Continuing with evaluation (database may be empty)\n")
|
||||
|
||||
# Step 3: Initialize clients
|
||||
print("🔧 Initializing clients...")
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# Initialize LLM client with database context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize embedder
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
print("✅ Clients initialized\n")
|
||||
|
||||
# Step 4: Process questions
|
||||
print(f"🔍 Processing {len(qa_items)} questions...")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Tracking variables
|
||||
latencies_search: List[float] = []
|
||||
latencies_llm: List[float] = []
|
||||
context_counts: List[int] = []
|
||||
context_chars: List[int] = []
|
||||
context_tokens: List[int] = []
|
||||
|
||||
# Metric lists
|
||||
f1_scores: List[float] = []
|
||||
bleu1_scores: List[float] = []
|
||||
jaccard_scores: List[float] = []
|
||||
locomo_f1_scores: List[float] = []
|
||||
|
||||
# Per-category tracking
|
||||
category_counts: Dict[str, int] = {}
|
||||
category_f1: Dict[str, List[float]] = {}
|
||||
category_bleu1: Dict[str, List[float]] = {}
|
||||
category_jaccard: Dict[str, List[float]] = {}
|
||||
category_locomo_f1: Dict[str, List[float]] = {}
|
||||
|
||||
# Detailed samples
|
||||
samples: List[Dict[str, Any]] = []
|
||||
|
||||
# Fixed anchor date for temporal resolution
|
||||
anchor_date = datetime(2023, 5, 8)
|
||||
|
||||
try:
|
||||
for idx, item in enumerate(qa_items, 1):
|
||||
question = item.get("question", "")
|
||||
ground_truth = item.get("answer", "")
|
||||
category = get_category_name(item)
|
||||
|
||||
# Ensure ground truth is a string
|
||||
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
print(f"[{idx}/{len(qa_items)}] Category: {category}")
|
||||
print(f"❓ Question: {question}")
|
||||
print(f"✅ Ground Truth: {ground_truth_str}")
|
||||
|
||||
# Step 4a: Retrieve relevant information
|
||||
t_search_start = time.time()
|
||||
try:
|
||||
retrieved_info = await retrieve_relevant_information(
|
||||
question=question,
|
||||
group_id=group_id,
|
||||
search_type=search_type,
|
||||
search_limit=search_limit,
|
||||
connector=connector,
|
||||
embedder=embedder
|
||||
)
|
||||
t_search_end = time.time()
|
||||
search_latency = (t_search_end - t_search_start) * 1000
|
||||
latencies_search.append(search_latency)
|
||||
|
||||
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Retrieval failed: {e}")
|
||||
retrieved_info = []
|
||||
search_latency = 0.0
|
||||
latencies_search.append(search_latency)
|
||||
|
||||
# Step 4b: Select and format context
|
||||
context_text = select_and_format_information(
|
||||
retrieved_info=retrieved_info,
|
||||
question=question,
|
||||
max_chars=context_char_budget
|
||||
)
|
||||
|
||||
# Resolve temporal references
|
||||
context_text = resolve_temporal_references(context_text, anchor_date)
|
||||
|
||||
# Add reference date to context
|
||||
if context_text:
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
|
||||
else:
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# Track context statistics
|
||||
context_counts.append(len(retrieved_info))
|
||||
context_chars.append(len(context_text))
|
||||
context_tokens.append(len(context_text.split()))
|
||||
|
||||
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
|
||||
|
||||
# Step 4c: Generate answer with LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Question: {question}\n\nContext:\n{context_text}"
|
||||
}
|
||||
]
|
||||
|
||||
t_llm_start = time.time()
|
||||
try:
|
||||
response = await llm_client.chat(messages=messages)
|
||||
t_llm_end = time.time()
|
||||
llm_latency = (t_llm_end - t_llm_start) * 1000
|
||||
latencies_llm.append(llm_latency)
|
||||
|
||||
# Extract prediction from response
|
||||
if hasattr(response, 'content'):
|
||||
prediction = response.content.strip()
|
||||
elif isinstance(response, dict):
|
||||
prediction = response["choices"][0]["message"]["content"].strip()
|
||||
else:
|
||||
prediction = "Unknown"
|
||||
|
||||
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ LLM failed: {e}")
|
||||
prediction = "Unknown"
|
||||
llm_latency = 0.0
|
||||
latencies_llm.append(llm_latency)
|
||||
|
||||
# Step 4d: Calculate metrics
|
||||
f1_val = f1_score(prediction, ground_truth_str)
|
||||
bleu1_val = bleu1(prediction, ground_truth_str)
|
||||
jaccard_val = jaccard(prediction, ground_truth_str)
|
||||
|
||||
# LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
|
||||
if item.get("category") == 1:
|
||||
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
|
||||
else:
|
||||
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
|
||||
|
||||
# Accumulate metrics
|
||||
f1_scores.append(f1_val)
|
||||
bleu1_scores.append(bleu1_val)
|
||||
jaccard_scores.append(jaccard_val)
|
||||
locomo_f1_scores.append(locomo_f1_val)
|
||||
|
||||
# Track by category
|
||||
category_counts[category] = category_counts.get(category, 0) + 1
|
||||
category_f1.setdefault(category, []).append(f1_val)
|
||||
category_bleu1.setdefault(category, []).append(bleu1_val)
|
||||
category_jaccard.setdefault(category, []).append(jaccard_val)
|
||||
category_locomo_f1.setdefault(category, []).append(locomo_f1_val)
|
||||
|
||||
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
|
||||
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
|
||||
print()
|
||||
|
||||
# Save sample details
|
||||
samples.append({
|
||||
"question": question,
|
||||
"ground_truth": ground_truth_str,
|
||||
"prediction": prediction,
|
||||
"category": category,
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"bleu1": bleu1_val,
|
||||
"jaccard": jaccard_val,
|
||||
"locomo_f1": locomo_f1_val
|
||||
},
|
||||
"retrieval": {
|
||||
"num_docs": len(retrieved_info),
|
||||
"context_length": len(context_text)
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_latency,
|
||||
"llm_ms": llm_latency
|
||||
}
|
||||
})
|
||||
|
||||
finally:
|
||||
# Close connector
|
||||
await connector.close()
|
||||
|
||||
# Step 5: Aggregate results
|
||||
print(f"\n{'='*60}")
|
||||
print("📊 Aggregating Results")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Overall metrics
|
||||
overall_metrics = {
|
||||
"f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0,
|
||||
"bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0,
|
||||
"jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0,
|
||||
"locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0
|
||||
}
|
||||
|
||||
# Per-category metrics
|
||||
by_category: Dict[str, Dict[str, Any]] = {}
|
||||
for cat in category_counts:
|
||||
f1_list = category_f1.get(cat, [])
|
||||
b1_list = category_bleu1.get(cat, [])
|
||||
j_list = category_jaccard.get(cat, [])
|
||||
lf_list = category_locomo_f1.get(cat, [])
|
||||
|
||||
by_category[cat] = {
|
||||
"count": category_counts[cat],
|
||||
"f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0,
|
||||
"bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0,
|
||||
"jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0,
|
||||
"locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0
|
||||
}
|
||||
|
||||
# Latency statistics
|
||||
latency = {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm)
|
||||
}
|
||||
|
||||
# Context statistics
|
||||
context_stats = {
|
||||
"avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0,
|
||||
"avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0,
|
||||
"avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0
|
||||
}
|
||||
|
||||
# Build result dictionary
|
||||
result = {
|
||||
"dataset": "locomo",
|
||||
"sample_size": len(qa_items),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_type": search_type,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"embedding_id": SELECTED_EMBEDDING_ID
|
||||
},
|
||||
"overall_metrics": overall_metrics,
|
||||
"by_category": by_category,
|
||||
"latency": latency,
|
||||
"context_stats": context_stats,
|
||||
"samples": samples
|
||||
}
|
||||
|
||||
# Step 6: Save results
|
||||
if output_dir is None:
|
||||
output_dir = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"results"
|
||||
)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Generate timestamped filename
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")
|
||||
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ Results saved to: {output_path}\n")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to save results: {e}")
|
||||
print("📊 Printing results to console instead:\n")
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Parse command-line arguments and run benchmark.
|
||||
|
||||
This function provides a CLI interface for running LoCoMo benchmarks
|
||||
with configurable parameters.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run LoCoMo benchmark evaluation",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample_size",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of QA pairs to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Database group ID for retrieval (uses default if not specified)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_type",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["keyword", "embedding", "hybrid"],
|
||||
help="Search strategy to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_limit",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Maximum number of documents to retrieve per query"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_char_budget",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Maximum characters for context"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reset_group",
|
||||
action="store_true",
|
||||
help="Clear and re-ingest data (not implemented)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_ingest",
|
||||
action="store_true",
|
||||
help="Skip data ingestion and use existing data in Neo4j"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save results (uses default if not specified)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Run benchmark
|
||||
result = asyncio.run(run_locomo_benchmark(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_type=args.search_type,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
reset_group=args.reset_group,
|
||||
skip_ingest=args.skip_ingest,
|
||||
output_dir=args.output_dir
|
||||
))
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
|
||||
# Check if there was an error
|
||||
if 'error' in result:
|
||||
print("❌ Benchmark Failed!")
|
||||
print(f"{'='*60}")
|
||||
print(f"Error: {result['error']}")
|
||||
return
|
||||
|
||||
print("🎉 Benchmark Complete!")
|
||||
print(f"{'='*60}")
|
||||
print("📊 Final Results:")
|
||||
print(f" Sample size: {result.get('sample_size', 0)}")
|
||||
print(f" F1: {result['overall_metrics']['f1']:.3f}")
|
||||
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
|
||||
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
|
||||
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
|
||||
|
||||
if result.get('context_stats'):
|
||||
print("\n📈 Context Statistics:")
|
||||
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
|
||||
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
|
||||
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
|
||||
|
||||
if result.get('latency'):
|
||||
print("\n⏱️ Latency Statistics:")
|
||||
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
|
||||
f"P50: {result['latency']['search']['p50']:.1f}ms, "
|
||||
f"P95: {result['latency']['search']['p95']:.1f}ms")
|
||||
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
|
||||
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
|
||||
f"P95: {result['latency']['llm']['p95']:.1f}ms")
|
||||
|
||||
if result.get('by_category'):
|
||||
print("\n📂 Results by Category:")
|
||||
for cat, metrics in result['by_category'].items():
|
||||
print(f" {cat}:")
|
||||
print(f" Count: {metrics['count']}")
|
||||
print(f" F1: {metrics['f1']:.3f}")
|
||||
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
|
||||
print(f" Jaccard: {metrics['jaccard']:.3f}")
|
||||
|
||||
print(f"\n{'='*60}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,225 +0,0 @@
|
||||
"""
|
||||
LoCoMo-specific metric calculations.
|
||||
|
||||
This module provides clean, simplified implementations of metrics used for
|
||||
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""
|
||||
Normalize text for LoCoMo evaluation.
|
||||
|
||||
Normalization steps:
|
||||
- Convert to lowercase
|
||||
- Remove commas
|
||||
- Remove stop words (a, an, the, and)
|
||||
- Remove punctuation
|
||||
- Normalize whitespace
|
||||
|
||||
Args:
|
||||
text: Input text to normalize
|
||||
|
||||
Returns:
|
||||
Normalized text string with consistent formatting
|
||||
|
||||
Examples:
|
||||
>>> normalize_text("The cat, and the dog")
|
||||
'cat dog'
|
||||
>>> normalize_text("Hello, World!")
|
||||
'hello world'
|
||||
"""
|
||||
# Ensure input is a string
|
||||
text = str(text) if text is not None else ""
|
||||
|
||||
# Convert to lowercase
|
||||
text = text.lower()
|
||||
|
||||
# Remove commas
|
||||
text = re.sub(r"[\,]", " ", text)
|
||||
|
||||
# Remove stop words
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
|
||||
# Remove punctuation (keep only word characters and whitespace)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
|
||||
# Normalize whitespace (collapse multiple spaces to single space)
|
||||
text = " ".join(text.split())
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
"""
|
||||
Calculate LoCoMo F1 score for single-answer questions.
|
||||
|
||||
Uses token-level precision and recall based on normalized text.
|
||||
Treats tokens as sets (no duplicate counting).
|
||||
|
||||
Args:
|
||||
prediction: Model's predicted answer
|
||||
ground_truth: Correct answer
|
||||
|
||||
Returns:
|
||||
F1 score between 0.0 and 1.0
|
||||
|
||||
Examples:
|
||||
>>> locomo_f1_score("Paris", "Paris")
|
||||
1.0
|
||||
>>> locomo_f1_score("The cat", "cat")
|
||||
1.0
|
||||
>>> locomo_f1_score("dog", "cat")
|
||||
0.0
|
||||
"""
|
||||
# Ensure inputs are strings
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
# Normalize and tokenize
|
||||
pred_tokens = normalize_text(pred_str).split()
|
||||
truth_tokens = normalize_text(truth_str).split()
|
||||
|
||||
# Handle empty cases
|
||||
if not pred_tokens or not truth_tokens:
|
||||
return 0.0
|
||||
|
||||
# Convert to sets for comparison
|
||||
pred_set = set(pred_tokens)
|
||||
truth_set = set(truth_tokens)
|
||||
|
||||
# Calculate true positives (intersection)
|
||||
true_positives = len(pred_set & truth_set)
|
||||
|
||||
# Calculate precision and recall
|
||||
precision = true_positives / len(pred_set) if pred_set else 0.0
|
||||
recall = true_positives / len(truth_set) if truth_set else 0.0
|
||||
|
||||
# Calculate F1 score
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
f1 = 2 * precision * recall / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
"""
|
||||
Calculate LoCoMo F1 score for multi-answer questions.
|
||||
|
||||
Handles comma-separated answers by:
|
||||
1. Splitting both prediction and ground truth by commas
|
||||
2. For each ground truth answer, finding the best matching prediction
|
||||
3. Averaging the F1 scores across all ground truth answers
|
||||
|
||||
Args:
|
||||
prediction: Model's predicted answer (may contain multiple comma-separated answers)
|
||||
ground_truth: Correct answer (may contain multiple comma-separated answers)
|
||||
|
||||
Returns:
|
||||
Average F1 score across all ground truth answers (0.0 to 1.0)
|
||||
|
||||
Examples:
|
||||
>>> locomo_multi_f1("Paris, London", "Paris, London")
|
||||
1.0
|
||||
>>> locomo_multi_f1("Paris", "Paris, London")
|
||||
0.5
|
||||
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
|
||||
0.5
|
||||
"""
|
||||
# Ensure inputs are strings
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
# Split by commas and strip whitespace
|
||||
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
|
||||
|
||||
# Handle empty cases
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
|
||||
# For each ground truth, find the best matching prediction
|
||||
f1_scores = []
|
||||
for gt in ground_truths:
|
||||
# Calculate F1 with each prediction and take the maximum
|
||||
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
|
||||
f1_scores.append(best_f1)
|
||||
|
||||
# Return average F1 across all ground truths
|
||||
return sum(f1_scores) / len(f1_scores)
|
||||
|
||||
|
||||
def get_category_name(item: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract and normalize category name from QA item.
|
||||
|
||||
Handles both numeric categories (1-4) and string categories with various formats.
|
||||
Supports multiple field names: "cat", "category", "type".
|
||||
|
||||
Category mapping:
|
||||
- 1 or "multi-hop" -> "Multi-Hop"
|
||||
- 2 or "temporal" -> "Temporal"
|
||||
- 3 or "open domain" -> "Open Domain"
|
||||
- 4 or "single-hop" -> "Single-Hop"
|
||||
|
||||
Args:
|
||||
item: QA item dictionary containing category information
|
||||
|
||||
Returns:
|
||||
Standardized category name or "unknown" if not found
|
||||
|
||||
Examples:
|
||||
>>> get_category_name({"category": 1})
|
||||
'Multi-Hop'
|
||||
>>> get_category_name({"cat": "temporal"})
|
||||
'Temporal'
|
||||
>>> get_category_name({"type": "Single-Hop"})
|
||||
'Single-Hop'
|
||||
"""
|
||||
# Numeric category mapping
|
||||
CATEGORY_MAP = {
|
||||
1: "Multi-Hop",
|
||||
2: "Temporal",
|
||||
3: "Open Domain",
|
||||
4: "Single-Hop",
|
||||
}
|
||||
|
||||
# String category aliases (case-insensitive)
|
||||
TYPE_ALIASES = {
|
||||
"single-hop": "Single-Hop",
|
||||
"singlehop": "Single-Hop",
|
||||
"single hop": "Single-Hop",
|
||||
"multi-hop": "Multi-Hop",
|
||||
"multihop": "Multi-Hop",
|
||||
"multi hop": "Multi-Hop",
|
||||
"open domain": "Open Domain",
|
||||
"opendomain": "Open Domain",
|
||||
"temporal": "Temporal",
|
||||
}
|
||||
|
||||
# Try "cat" field first (string category)
|
||||
cat = item.get("cat")
|
||||
if isinstance(cat, str) and cat.strip():
|
||||
name = cat.strip()
|
||||
lower = name.lower()
|
||||
return TYPE_ALIASES.get(lower, name)
|
||||
|
||||
# Try "category" field (can be int or string)
|
||||
cat_num = item.get("category")
|
||||
if isinstance(cat_num, int):
|
||||
return CATEGORY_MAP.get(cat_num, "unknown")
|
||||
elif isinstance(cat_num, str) and cat_num.strip():
|
||||
lower = cat_num.strip().lower()
|
||||
return TYPE_ALIASES.get(lower, cat_num.strip())
|
||||
|
||||
# Try "type" field as fallback
|
||||
cat_type = item.get("type")
|
||||
if isinstance(cat_type, str) and cat_type.strip():
|
||||
lower = cat_type.strip().lower()
|
||||
return TYPE_ALIASES.get(lower, cat_type.strip())
|
||||
|
||||
return "unknown"
|
||||
@@ -1,810 +0,0 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 1
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
if src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
|
||||
def _loc_normalize(text: str) -> str:
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text)
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# 尝试从 metrics.py 导入基础指标
|
||||
try:
|
||||
from common.metrics import bleu1, f1_score, jaccard
|
||||
print("✅ 从 metrics.py 导入基础指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
||||
# 回退到本地实现
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
r_tokens = _loc_normalize(ref_str).split()
|
||||
if not p_tokens and not r_tokens:
|
||||
return 1.0
|
||||
if not p_tokens or not r_tokens:
|
||||
return 0.0
|
||||
p_set = set(p_tokens)
|
||||
r_set = set(r_tokens)
|
||||
tp = len(p_set & r_set)
|
||||
precision = tp / len(p_set) if p_set else 0.0
|
||||
recall = tp / len(r_set) if r_set else 0.0
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
r_tokens = _loc_normalize(ref_str).split()
|
||||
if not p_tokens:
|
||||
return 0.0
|
||||
|
||||
r_counts = {}
|
||||
for t in r_tokens:
|
||||
r_counts[t] = r_counts.get(t, 0) + 1
|
||||
|
||||
clipped = 0
|
||||
p_counts = {}
|
||||
for t in p_tokens:
|
||||
p_counts[t] = p_counts.get(t, 0) + 1
|
||||
|
||||
for t, c in p_counts.items():
|
||||
clipped += min(c, r_counts.get(t, 0))
|
||||
|
||||
precision = clipped / max(len(p_tokens), 1)
|
||||
ref_len = len(r_tokens)
|
||||
pred_len = len(p_tokens)
|
||||
|
||||
if pred_len > ref_len or pred_len == 0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
||||
|
||||
return bp * precision
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p = set(_loc_normalize(pred_str).split())
|
||||
r = set(_loc_normalize(ref_str).split())
|
||||
if not p and not r:
|
||||
return 1.0
|
||||
if not p or not r:
|
||||
return 0.0
|
||||
return len(p & r) / len(p | r)
|
||||
|
||||
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
|
||||
try:
|
||||
# 添加 evaluation 目录路径
|
||||
evaluation_dir = os.path.join(project_root, "evaluation")
|
||||
if evaluation_dir not in sys.path:
|
||||
sys.path.insert(0, evaluation_dir)
|
||||
|
||||
# 尝试从不同位置导入
|
||||
try:
|
||||
from locomo.qwen_search_eval import (
|
||||
_resolve_relative_times,
|
||||
loc_f1_score,
|
||||
loc_multi_f1,
|
||||
)
|
||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError:
|
||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
|
||||
# 回退到本地实现 LoCoMo 特定函数
|
||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
||||
t = str(text) if text is not None else ""
|
||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor - timedelta(days=n)).date().isoformat()
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor + timedelta(days=n)).date().isoformat()
|
||||
|
||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
return t
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
p_tokens = _loc_normalize(prediction).split()
|
||||
g_tokens = _loc_normalize(ground_truth).split()
|
||||
if not p_tokens or not g_tokens:
|
||||
return 0.0
|
||||
p = set(p_tokens)
|
||||
g = set(g_tokens)
|
||||
tp = len(p & g)
|
||||
precision = tp / len(p) if p else 0.0
|
||||
recall = tp / len(g) if g else 0.0
|
||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
def _f1(a: str, b: str) -> float:
|
||||
return loc_f1_score(a, b)
|
||||
vals = []
|
||||
for gt in ground_truths:
|
||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
|
||||
"""基于问题关键词智能选择上下文"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 提取问题关键词(只保留有意义的词)
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
print(f"🔍 问题关键词: {question_words}")
|
||||
|
||||
# 给每个上下文打分
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(contexts):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# 关键词匹配得分
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# 关键词出现次数越多,得分越高
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# 上下文长度得分(适中的长度更好)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000: # 理想长度范围
|
||||
score += 5
|
||||
elif context_len >= 2000: # 太长可能包含无关信息
|
||||
score += 2
|
||||
|
||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# 按得分排序
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择高得分的上下文,直到达到字符限制
|
||||
selected = []
|
||||
total_chars = 0
|
||||
selected_count = 0
|
||||
|
||||
print("📊 上下文相关性分析:")
|
||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
selected_count += 1
|
||||
else:
|
||||
# 如果这个上下文得分很高但放不下,尝试截取
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# 找到包含关键词的部分
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
if len(truncated) > 100: # 确保有足够内容
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total_chars += len(truncated)
|
||||
selected_count += 1
|
||||
break # 不再尝试添加更多上下文
|
||||
|
||||
result = "\n\n".join(selected)
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
||||
return result
|
||||
|
||||
|
||||
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
|
||||
"""根据问题复杂度和进度动态调整检索参数"""
|
||||
|
||||
# 分析问题复杂度
|
||||
word_count = len(question.split())
|
||||
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
|
||||
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
|
||||
|
||||
# 根据进度调整 - 后期问题可能需要更精确的检索
|
||||
progress_factor = question_index / total_questions
|
||||
|
||||
base_limit = 12
|
||||
if has_temporal and has_multi_hop:
|
||||
base_limit = 20
|
||||
elif word_count > 8:
|
||||
base_limit = 16
|
||||
|
||||
# 随着测试进行,逐渐收紧检索范围
|
||||
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
|
||||
|
||||
# 动态调整最大字符数
|
||||
max_chars = 8000 + 4000 * (1 - progress_factor)
|
||||
|
||||
return {
|
||||
"limit": adjusted_limit,
|
||||
"max_chars": int(max_chars)
|
||||
}
|
||||
|
||||
|
||||
class EnhancedEvaluationMonitor:
|
||||
def __init__(self, reset_interval=5, performance_threshold=0.6):
|
||||
self.question_count = 0
|
||||
self.reset_interval = reset_interval
|
||||
self.performance_threshold = performance_threshold
|
||||
self.consecutive_low_scores = 0
|
||||
self.performance_history = []
|
||||
self.recent_f1_scores = []
|
||||
|
||||
def should_reset_connections(self, current_f1=None):
|
||||
"""基于计数和性能双重判断"""
|
||||
# 定期重置
|
||||
if self.question_count % self.reset_interval == 0:
|
||||
return True
|
||||
|
||||
# 性能驱动的重置
|
||||
if current_f1 is not None and current_f1 < self.performance_threshold:
|
||||
self.consecutive_low_scores += 1
|
||||
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
|
||||
print("🚨 连续低分,触发紧急重置")
|
||||
self.consecutive_low_scores = 0
|
||||
return True
|
||||
else:
|
||||
self.consecutive_low_scores = 0
|
||||
|
||||
return False
|
||||
|
||||
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
|
||||
"""记录性能指标,检测衰减"""
|
||||
self.performance_history.append({
|
||||
'index': question_index,
|
||||
'metrics': metrics,
|
||||
'context_length': context_length,
|
||||
'retrieved_docs': retrieved_docs,
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
# 记录最近的F1分数
|
||||
self.recent_f1_scores.append(metrics['f1'])
|
||||
if len(self.recent_f1_scores) > 5:
|
||||
self.recent_f1_scores.pop(0)
|
||||
|
||||
def get_recent_performance(self):
|
||||
"""获取近期平均性能"""
|
||||
if not self.recent_f1_scores:
|
||||
return 0.5
|
||||
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
|
||||
|
||||
def get_performance_trend(self):
|
||||
"""分析性能趋势"""
|
||||
if len(self.performance_history) < 2:
|
||||
return "stable"
|
||||
|
||||
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
|
||||
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
|
||||
|
||||
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
|
||||
return "stable"
|
||||
|
||||
recent_avg = sum(recent_metrics) / len(recent_metrics)
|
||||
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
|
||||
|
||||
if recent_avg < earlier_avg * 0.8:
|
||||
return "degrading"
|
||||
elif recent_avg > earlier_avg * 1.1:
|
||||
return "improving"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
|
||||
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
|
||||
"""基于问题复杂度和近期性能动态调整检索参数"""
|
||||
|
||||
# 基础参数
|
||||
base_params = get_dynamic_search_params(question, question_index, total_questions)
|
||||
|
||||
# 性能自适应调整
|
||||
if recent_performance < 0.5: # 近期表现差
|
||||
# 增加检索范围,尝试获取更多上下文
|
||||
base_params["limit"] = min(base_params["limit"] + 5, 25)
|
||||
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
|
||||
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
||||
|
||||
elif recent_performance > 0.8: # 近期表现好
|
||||
# 收紧检索,提高精度
|
||||
base_params["limit"] = max(base_params["limit"] - 2, 8)
|
||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
|
||||
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
||||
|
||||
# 中间阶段特殊处理
|
||||
mid_sequence_factor = abs(question_index / total_questions - 0.5)
|
||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
||||
print("🎯 中间阶段:使用更精确的检索策略")
|
||||
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
|
||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
|
||||
|
||||
return base_params
|
||||
|
||||
|
||||
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
|
||||
"""考虑问题序列位置的智能选择"""
|
||||
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 在序列中间阶段使用更严格的筛选
|
||||
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
|
||||
|
||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
||||
print("🎯 中间阶段:使用严格上下文筛选")
|
||||
|
||||
# 提取问题关键词
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
# 只保留高度相关的上下文
|
||||
filtered_contexts = []
|
||||
for context in contexts:
|
||||
context_lower = context.lower()
|
||||
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
|
||||
|
||||
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
|
||||
if any(char.isdigit() for char in context):
|
||||
relevance_score += 2
|
||||
|
||||
# 提高阈值:只有得分>=3的上下文才保留
|
||||
if relevance_score >= 3:
|
||||
filtered_contexts.append(context)
|
||||
else:
|
||||
print(f" - 过滤低分上下文: 得分={relevance_score}")
|
||||
|
||||
contexts = filtered_contexts
|
||||
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
|
||||
|
||||
# 使用原有的智能选择逻辑
|
||||
return smart_context_selection(contexts, question, max_chars)
|
||||
|
||||
|
||||
async def run_enhanced_evaluation():
|
||||
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 加载数据
|
||||
# 获取项目根目录
|
||||
current_file = os.path.abspath(__file__)
|
||||
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
|
||||
memory_dir = os.path.dirname(evaluation_dir) # memory目录
|
||||
data_path = os.path.join(memory_dir, "data", "locomo10.json")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
qa_items = []
|
||||
if isinstance(raw, list):
|
||||
for entry in raw:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
|
||||
items = qa_items[:20] # 测试多少个问题
|
||||
|
||||
# 初始化增强监控器
|
||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
||||
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化embedder
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 初始化连接器
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化结果字典
|
||||
results = {
|
||||
"questions": [],
|
||||
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
|
||||
"category_metrics": {},
|
||||
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
|
||||
"performance_trend": "stable",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"enhanced_strategy": True
|
||||
}
|
||||
|
||||
total_f1 = 0.0
|
||||
total_bleu1 = 0.0
|
||||
total_jaccard = 0.0
|
||||
total_loc_f1 = 0.0
|
||||
total_context_length = 0
|
||||
total_retrieved_docs = 0
|
||||
category_stats = {}
|
||||
|
||||
try:
|
||||
for i, item in enumerate(items):
|
||||
monitor.question_count += 1
|
||||
|
||||
# 获取近期性能用于重置判断
|
||||
recent_performance = monitor.get_recent_performance()
|
||||
|
||||
# 增强的重置判断
|
||||
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
|
||||
if should_reset and i > 0:
|
||||
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
|
||||
await connector.close()
|
||||
connector = Neo4jConnector() # 创建新连接
|
||||
print("✅ 连接重置完成")
|
||||
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
|
||||
print(f"✅ 真实答案: {ref_str}")
|
||||
|
||||
# 分类别统计
|
||||
category = "Unknown"
|
||||
if item.get("category") == 1:
|
||||
category = "Multi-Hop"
|
||||
elif item.get("category") == 2:
|
||||
category = "Temporal"
|
||||
elif item.get("category") == 3:
|
||||
category = "Open Domain"
|
||||
elif item.get("category") == 4:
|
||||
category = "Single-Hop"
|
||||
|
||||
# 增强的检索参数
|
||||
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
|
||||
search_limit = search_params["limit"]
|
||||
max_chars = search_params["max_chars"]
|
||||
|
||||
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
|
||||
|
||||
# 使用项目标准的混合检索方法
|
||||
t0 = time.time()
|
||||
contexts_all = []
|
||||
|
||||
try:
|
||||
# 使用统一的搜索服务
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
|
||||
print("🔀 使用混合搜索服务...")
|
||||
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type="hybrid",
|
||||
group_id="locomo_sk",
|
||||
limit=20,
|
||||
include=["statements", "chunks", "entities", "summaries"],
|
||||
alpha=0.6, # BM25权重
|
||||
embedding_id=SELECTED_EMBEDDING_ID
|
||||
)
|
||||
|
||||
# 处理搜索结果 - 新的搜索服务返回统一的结构
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
||||
|
||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
except Exception as e:
|
||||
print(f"❌ 检索失败: {e}")
|
||||
contexts_all = []
|
||||
|
||||
t1 = time.time()
|
||||
search_time = (t1 - t0) * 1000
|
||||
|
||||
# 增强的上下文选择
|
||||
context_text = ""
|
||||
if contexts_all:
|
||||
# 使用增强的上下文选择
|
||||
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
|
||||
|
||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
||||
if len(context_text) > max_chars:
|
||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
||||
|
||||
# 时间解析
|
||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
||||
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
||||
|
||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
||||
|
||||
# 显示不同上下文的预览(不只是第一条)
|
||||
print("🔍 上下文预览:")
|
||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
||||
preview = context[:150].replace('\n', ' ')
|
||||
print(f" 上下文{j+1}: {preview}...")
|
||||
|
||||
# 🔍 调试:检查答案是否在上下文中
|
||||
if ref_str and ref_str.strip():
|
||||
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
|
||||
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
|
||||
|
||||
else:
|
||||
print("❌ 没有检索到有效上下文")
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# LLM 回答
|
||||
messages = [
|
||||
{"role": "system", "content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)},
|
||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
try:
|
||||
# 使用异步调用
|
||||
resp = await llm.chat(messages=messages)
|
||||
# 兼容不同的响应格式
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
||||
except Exception as e:
|
||||
print(f"❌ LLM 生成失败: {e}")
|
||||
pred = "Unknown"
|
||||
t3 = time.time()
|
||||
llm_time = (t3 - t2) * 1000
|
||||
|
||||
# 计算指标 - 使用导入的指标函数
|
||||
f1_val = f1_score(pred, ref_str)
|
||||
bleu1_val = bleu1(pred, ref_str)
|
||||
jaccard_val = jaccard(pred, ref_str)
|
||||
loc_f1_val = loc_f1_score(pred, ref_str)
|
||||
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
|
||||
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
|
||||
|
||||
# 更新统计
|
||||
total_f1 += f1_val
|
||||
total_bleu1 += bleu1_val
|
||||
total_jaccard += jaccard_val
|
||||
total_loc_f1 += loc_f1_val
|
||||
total_context_length += len(context_text)
|
||||
total_retrieved_docs += len(contexts_all)
|
||||
|
||||
if category not in category_stats:
|
||||
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
|
||||
|
||||
category_stats[category]["count"] += 1
|
||||
category_stats[category]["f1_sum"] += f1_val
|
||||
category_stats[category]["b1_sum"] += bleu1_val
|
||||
category_stats[category]["j_sum"] += jaccard_val
|
||||
category_stats[category]["loc_f1_sum"] += loc_f1_val
|
||||
|
||||
# 记录性能指标
|
||||
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
|
||||
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
|
||||
|
||||
# 保存结果
|
||||
question_result = {
|
||||
"question": q,
|
||||
"ground_truth": ref_str,
|
||||
"prediction": pred,
|
||||
"category": category,
|
||||
"metrics": metrics,
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": search_limit,
|
||||
"max_chars": max_chars,
|
||||
"recent_performance": recent_performance
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_time,
|
||||
"llm_ms": llm_time
|
||||
}
|
||||
}
|
||||
|
||||
results["questions"].append(question_result)
|
||||
|
||||
print("="*60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 评估过程中发生错误: {e}")
|
||||
# 即使出错,也返回已有的结果
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
# 计算总体指标
|
||||
n = len(items)
|
||||
if n > 0:
|
||||
results["overall_metrics"] = {
|
||||
"f1": total_f1 / n,
|
||||
"b1": total_bleu1 / n,
|
||||
"j": total_jaccard / n,
|
||||
"loc_f1": total_loc_f1 / n
|
||||
}
|
||||
|
||||
for category, stats in category_stats.items():
|
||||
count = stats["count"]
|
||||
results["category_metrics"][category] = {
|
||||
"count": count,
|
||||
"f1": stats["f1_sum"] / count,
|
||||
"bleu1": stats["b1_sum"] / count,
|
||||
"jaccard": stats["j_sum"] / count,
|
||||
"loc_f1": stats["loc_f1_sum"] / count
|
||||
}
|
||||
|
||||
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
|
||||
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
|
||||
|
||||
# 分析性能趋势
|
||||
results["performance_trend"] = monitor.get_performance_trend()
|
||||
results["reset_interval"] = monitor.reset_interval
|
||||
results["total_questions_processed"] = monitor.question_count
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
|
||||
print("📋 增强特性:")
|
||||
print(" - 双重重置策略:定期重置 + 性能驱动重置")
|
||||
print(" - 动态检索参数:基于近期性能自适应调整")
|
||||
print(" - 中间阶段严格筛选:提高上下文质量要求")
|
||||
print(" - 连续性能监控:实时检测性能衰减")
|
||||
|
||||
result = asyncio.run(run_enhanced_evaluation())
|
||||
|
||||
print("\n📊 最终评估结果:")
|
||||
print("总体指标:")
|
||||
print(f" F1: {result['overall_metrics']['f1']:.4f}")
|
||||
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
|
||||
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
|
||||
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
|
||||
|
||||
print("\n分类别指标:")
|
||||
for category, metrics in result['category_metrics'].items():
|
||||
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
|
||||
|
||||
print("\n检索统计:")
|
||||
stats = result['retrieval_stats']
|
||||
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
|
||||
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
|
||||
|
||||
print(f"\n性能趋势: {result['performance_trend']}")
|
||||
print(f"重置间隔: 每{result['reset_interval']}个问题")
|
||||
print(f"处理问题总数: {result['total_questions_processed']}")
|
||||
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
|
||||
|
||||
|
||||
# 保存结果到指定目录
|
||||
# 使用代码文件所在目录的绝对路径
|
||||
current_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
output_dir = os.path.join(current_file_dir, "results")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n详细结果已保存到: {output_file}")
|
||||
@@ -1,626 +0,0 @@
|
||||
"""
|
||||
LoCoMo Utilities Module
|
||||
|
||||
This module provides helper functions for the LoCoMo benchmark evaluation:
|
||||
- Data loading from JSON files
|
||||
- Conversation extraction for ingestion
|
||||
- Temporal reference resolution
|
||||
- Context selection and formatting
|
||||
- Retrieval wrapper functions
|
||||
- Ingestion wrapper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.memory.utils.definitions import PROJECT_ROOT
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
|
||||
|
||||
def load_locomo_data(
|
||||
data_path: str,
|
||||
sample_size: int,
|
||||
conversation_index: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load LoCoMo dataset from JSON file.
|
||||
|
||||
The LoCoMo dataset structure is a list of conversation objects, where each
|
||||
object contains a "qa" list of question-answer pairs.
|
||||
|
||||
Args:
|
||||
data_path: Path to locomo10.json file
|
||||
sample_size: Number of QA pairs to load (limits total QA items returned)
|
||||
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
|
||||
|
||||
Returns:
|
||||
List of QA item dictionaries, each containing:
|
||||
- question: str
|
||||
- answer: str
|
||||
- category: int (1-4)
|
||||
- evidence: List[str]
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If data_path does not exist
|
||||
json.JSONDecodeError: If file is not valid JSON
|
||||
IndexError: If conversation_index is out of range
|
||||
"""
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
# LoCoMo data structure: list of objects, each with a "qa" list
|
||||
qa_items: List[Dict[str, Any]] = []
|
||||
|
||||
if isinstance(raw, list):
|
||||
# Only load QA pairs from the specified conversation
|
||||
if conversation_index < len(raw):
|
||||
entry = raw[conversation_index]
|
||||
if isinstance(entry, dict) and "qa" in entry:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Conversation index {conversation_index} out of range. "
|
||||
f"Dataset has {len(raw)} conversations."
|
||||
)
|
||||
else:
|
||||
# Fallback: single object with qa list
|
||||
if conversation_index == 0:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Conversation index {conversation_index} out of range. "
|
||||
f"Dataset has only 1 conversation."
|
||||
)
|
||||
|
||||
# Return only the requested sample size
|
||||
return qa_items[:sample_size]
|
||||
|
||||
|
||||
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
|
||||
"""
|
||||
Extract conversation texts from LoCoMo data for ingestion.
|
||||
|
||||
This function extracts the raw conversation dialogues from the LoCoMo dataset
|
||||
so they can be ingested into the memory system. Each conversation is formatted
|
||||
as a multi-line string with "role: message" format.
|
||||
|
||||
Args:
|
||||
data_path: Path to locomo10.json file
|
||||
max_dialogues: Maximum number of dialogues to extract (default: 1)
|
||||
|
||||
Returns:
|
||||
List of conversation strings formatted for ingestion.
|
||||
Each string contains multiple lines in format "role: message"
|
||||
|
||||
Example output:
|
||||
[
|
||||
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
|
||||
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
|
||||
]
|
||||
"""
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
# Ensure we have a list of entries
|
||||
entries = raw if isinstance(raw, list) else [raw]
|
||||
|
||||
contents: List[str] = []
|
||||
|
||||
for i, entry in enumerate(entries[:max_dialogues]):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
conv = entry.get("conversation", {})
|
||||
|
||||
if not isinstance(conv, dict):
|
||||
continue
|
||||
|
||||
lines: List[str] = []
|
||||
|
||||
# Collect all session_* messages
|
||||
for key, val in sorted(conv.items()):
|
||||
if isinstance(val, list) and key.startswith("session_"):
|
||||
for msg in val:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
role = msg.get("speaker") or "User"
|
||||
text = msg.get("text") or ""
|
||||
text = str(text).strip()
|
||||
|
||||
if not text:
|
||||
continue
|
||||
|
||||
lines.append(f"{role}: {text}")
|
||||
|
||||
if lines:
|
||||
contents.append("\n".join(lines))
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
|
||||
"""
|
||||
Resolve relative temporal references to absolute dates.
|
||||
|
||||
This function converts relative time expressions (like "today", "yesterday",
|
||||
"3 days ago") into absolute ISO date strings based on an anchor date.
|
||||
|
||||
Supported patterns:
|
||||
- today, yesterday, tomorrow
|
||||
- X days ago, in X days
|
||||
- last week, next week
|
||||
|
||||
Args:
|
||||
text: Text containing temporal references
|
||||
anchor_date: Reference date for resolution (datetime object)
|
||||
|
||||
Returns:
|
||||
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
|
||||
|
||||
Example:
|
||||
>>> anchor = datetime(2023, 5, 8)
|
||||
>>> resolve_temporal_references("I saw him yesterday", anchor)
|
||||
"I saw him 2023-05-07"
|
||||
"""
|
||||
# Ensure input is a string
|
||||
t = str(text) if text is not None else ""
|
||||
|
||||
# today / yesterday / tomorrow
|
||||
t = re.sub(
|
||||
r"\btoday\b",
|
||||
anchor_date.date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\byesterday\b",
|
||||
(anchor_date - timedelta(days=1)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\btomorrow\b",
|
||||
(anchor_date + timedelta(days=1)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# X days ago
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor_date - timedelta(days=n)).date().isoformat()
|
||||
|
||||
# in X days
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor_date + timedelta(days=n)).date().isoformat()
|
||||
|
||||
t = re.sub(
|
||||
r"\b(\d+)\s+days?\s+ago\b",
|
||||
_ago_repl,
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\bin\s+(\d+)\s+days?\b",
|
||||
_in_repl,
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# last week / next week (approximate as 7 days)
|
||||
t = re.sub(
|
||||
r"\blast\s+week\b",
|
||||
(anchor_date - timedelta(days=7)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\bnext\s+week\b",
|
||||
(anchor_date + timedelta(days=7)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def select_and_format_information(
|
||||
retrieved_info: List[str],
|
||||
question: str,
|
||||
max_chars: int = 8000
|
||||
) -> str:
|
||||
"""
|
||||
Intelligently select and format most relevant retrieved information for LLM prompt.
|
||||
|
||||
This function scores each piece of retrieved information based on keyword matching
|
||||
with the question, then selects the highest-scoring pieces up to the character limit.
|
||||
|
||||
Scoring criteria:
|
||||
- Keyword matches (higher weight for multiple occurrences)
|
||||
- Context length (moderate length preferred)
|
||||
- Position (earlier contexts get bonus points)
|
||||
|
||||
Args:
|
||||
retrieved_info: List of retrieved information strings (chunks, statements, entities)
|
||||
question: Question being answered
|
||||
max_chars: Maximum total characters to include in final prompt
|
||||
|
||||
Returns:
|
||||
Formatted string combining the most relevant information for LLM prompt.
|
||||
Contexts are separated by double newlines.
|
||||
|
||||
Example:
|
||||
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
|
||||
>>> question = "Where did Alice go?"
|
||||
>>> select_and_format_information(contexts, question, max_chars=100)
|
||||
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
|
||||
"""
|
||||
if not retrieved_info:
|
||||
return ""
|
||||
|
||||
# Extract question keywords (filter out stop words and short words)
|
||||
question_lower = question.lower()
|
||||
stop_words = {
|
||||
'what', 'when', 'where', 'who', 'why', 'how',
|
||||
'did', 'do', 'does', 'is', 'are', 'was', 'were',
|
||||
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
|
||||
}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {
|
||||
word for word in question_words
|
||||
if word not in stop_words and len(word) > 2
|
||||
}
|
||||
|
||||
# Score each context
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(retrieved_info):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# Keyword matching score
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# Multiple occurrences increase score
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# Length score (prefer moderate length)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000:
|
||||
score += 5
|
||||
elif context_len >= 2000:
|
||||
score += 2
|
||||
|
||||
# Position bonus (earlier contexts often more relevant)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# Sort by score (descending)
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Select contexts up to character limit
|
||||
selected = []
|
||||
total_chars = 0
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
else:
|
||||
# Try to include high-scoring context by truncating
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# Find lines with keywords
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
selected.append(truncated + "\n[Content truncated...]")
|
||||
total_chars += len(truncated)
|
||||
break
|
||||
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
async def retrieve_relevant_information(
|
||||
question: str,
|
||||
group_id: str,
|
||||
search_type: str,
|
||||
search_limit: int,
|
||||
connector: Any,
|
||||
embedder: Any
|
||||
) -> List[str]:
|
||||
"""
|
||||
Retrieve relevant information from memory graph for a question.
|
||||
|
||||
This function searches the Neo4j memory graph (populated during ingestion) and
|
||||
returns relevant chunks, statements, and entity information that might help
|
||||
answer the question.
|
||||
|
||||
The function supports three search types:
|
||||
- "keyword": Full-text search using Cypher queries
|
||||
- "embedding": Vector similarity search using embeddings
|
||||
- "hybrid": Combination of keyword and embedding search with reranking
|
||||
|
||||
Args:
|
||||
question: Question to search for
|
||||
group_id: Database group ID (identifies which conversation memory to search)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max memory pieces to retrieve
|
||||
connector: Neo4j connector instance
|
||||
embedder: Embedder client instance
|
||||
|
||||
Returns:
|
||||
List of text strings (chunks, statements, entity summaries) from memory graph.
|
||||
Each string represents a piece of retrieved information.
|
||||
|
||||
Raises:
|
||||
Exception: If search fails (caught and returns empty list)
|
||||
"""
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph,
|
||||
search_graph_by_embedding
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
|
||||
contexts_all: List[str] = []
|
||||
|
||||
try:
|
||||
if search_type == "embedding":
|
||||
# Embedding-based search
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Build context from chunks
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
# Add statements
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
# Add summaries
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# Add top entities (limit to 3 to avoid noise)
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = (
|
||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
||||
if scored else entities[:3]
|
||||
)
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(
|
||||
f"EntitySummary: {name}"
|
||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
||||
)
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
elif search_type == "keyword":
|
||||
# Keyword-based search
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit
|
||||
)
|
||||
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
|
||||
# Build context from dialogues
|
||||
for d in dialogs:
|
||||
content = str(d.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
# Add statements
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
# Add entity names
|
||||
if entities:
|
||||
entity_names = [
|
||||
str(e.get("name", "")).strip()
|
||||
for e in entities[:5]
|
||||
if e.get("name")
|
||||
]
|
||||
if entity_names:
|
||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
||||
|
||||
else: # hybrid
|
||||
# Hybrid search with fallback to embedding
|
||||
try:
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
)
|
||||
|
||||
# Handle flat structure (new API format)
|
||||
if search_results and isinstance(search_results, dict):
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Check if we got results
|
||||
if not (chunks or statements or entities or summaries):
|
||||
# Try nested structure (backward compatibility)
|
||||
reranked = search_results.get("reranked_results", {})
|
||||
if reranked and isinstance(reranked, dict):
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
else:
|
||||
raise ValueError("Hybrid search returned empty results")
|
||||
else:
|
||||
raise ValueError("Hybrid search returned empty results")
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to embedding search
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Build context (same for both hybrid and fallback)
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# Add top entities
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = (
|
||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
||||
if scored else entities[:3]
|
||||
)
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(
|
||||
f"EntitySummary: {name}"
|
||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
||||
)
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
except Exception as e:
|
||||
# Return empty list on error
|
||||
contexts_all = []
|
||||
|
||||
return contexts_all
|
||||
|
||||
|
||||
async def ingest_conversations_if_needed(
|
||||
conversations: List[str],
|
||||
group_id: str,
|
||||
reset: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Wrapper for conversation ingestion using external extraction pipeline.
|
||||
|
||||
This function populates the Neo4j database with processed conversation data
|
||||
(chunks, statements, entities) so that the retrieval system has memory to search.
|
||||
|
||||
The ingestion process:
|
||||
1. Parses conversation text into dialogue messages
|
||||
2. Chunks the dialogues into semantic units
|
||||
3. Extracts statements and entities using LLM
|
||||
4. Generates embeddings for all content
|
||||
5. Stores everything in Neo4j graph database
|
||||
|
||||
Args:
|
||||
conversations: List of raw conversation texts from LoCoMo dataset
|
||||
Example: ["User: I went to Paris. AI: When was that?", ...]
|
||||
group_id: Target group ID for database storage
|
||||
reset: Whether to clear existing data first (not implemented in wrapper)
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
|
||||
Note:
|
||||
The external function uses "contexts" to mean "conversation texts".
|
||||
This runs the full extraction pipeline: chunking → entity extraction →
|
||||
statement extraction → embedding → Neo4j storage.
|
||||
"""
|
||||
try:
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=conversations,
|
||||
group_id=group_id,
|
||||
save_chunk_output=True
|
||||
)
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to ingest conversations: {e}")
|
||||
return False
|
||||
@@ -1,878 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
import re
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
def _loc_normalize(text: str) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text) # 去掉逗号
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# 追加:相对时间归一化为绝对日期(有限支持:today/yesterday/tomorrow/X days ago/in X days/last week/next week)
|
||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
t = str(text) if text is not None else ""
|
||||
# today / yesterday / tomorrow
|
||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
# X days ago / in X days
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor - timedelta(days=n)).date().isoformat()
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor + timedelta(days=n)).date().isoformat()
|
||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
||||
# last week / next week(以7天近似)
|
||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
return t
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
# 单答案 F1:按词集合计算(近似原始实现,去除词干依赖)
|
||||
# 确保输入是字符串
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
g_tokens = _loc_normalize(truth_str).split()
|
||||
if not p_tokens or not g_tokens:
|
||||
return 0.0
|
||||
p = set(p_tokens)
|
||||
g = set(g_tokens)
|
||||
tp = len(p & g)
|
||||
precision = tp / len(p) if p else 0.0
|
||||
recall = tp / len(g) if g else 0.0
|
||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
# 多答案 F1:prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
|
||||
# 确保输入是字符串
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
def _f1(a: str, b: str) -> float:
|
||||
return loc_f1_score(a, b)
|
||||
vals = []
|
||||
for gt in ground_truths:
|
||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
|
||||
CATEGORY_MAP_NUM_TO_NAME = {
|
||||
4: "Single-Hop",
|
||||
1: "Multi-Hop",
|
||||
3: "Open Domain",
|
||||
2: "Temporal",
|
||||
}
|
||||
|
||||
_TYPE_ALIASES = {
|
||||
"single-hop": "Single-Hop",
|
||||
"singlehop": "Single-Hop",
|
||||
"single hop": "Single-Hop",
|
||||
"multi-hop": "Multi-Hop",
|
||||
"multihop": "Multi-Hop",
|
||||
"multi hop": "Multi-Hop",
|
||||
"open domain": "Open Domain",
|
||||
"opendomain": "Open Domain",
|
||||
"temporal": "Temporal",
|
||||
}
|
||||
|
||||
def get_category_label(item: Dict[str, Any]) -> str:
|
||||
# 1) 直接用字符串 cat
|
||||
cat = item.get("cat")
|
||||
if isinstance(cat, str) and cat.strip():
|
||||
name = cat.strip()
|
||||
lower = name.lower()
|
||||
return _TYPE_ALIASES.get(lower, name)
|
||||
# 2) 数字 category 转名称
|
||||
cat_num = item.get("category")
|
||||
if isinstance(cat_num, int):
|
||||
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
|
||||
# 3) 备用 type 字段
|
||||
t = item.get("type")
|
||||
if isinstance(t, str) and t.strip():
|
||||
lower = t.strip().lower()
|
||||
return _TYPE_ALIASES.get(lower, t.strip())
|
||||
return "unknown"
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
|
||||
"""基于问题关键词智能选择上下文"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 提取问题关键词(只保留有意义的词)
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
print(f"🔍 问题关键词: {question_words}")
|
||||
|
||||
# 给每个上下文打分
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(contexts):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# 关键词匹配得分
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# 关键词出现次数越多,得分越高
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# 上下文长度得分(适中的长度更好)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000: # 理想长度范围
|
||||
score += 5
|
||||
elif context_len >= 2000: # 太长可能包含无关信息
|
||||
score += 2
|
||||
|
||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# 按得分排序
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择高得分的上下文,直到达到字符限制
|
||||
selected = []
|
||||
total_chars = 0
|
||||
selected_count = 0
|
||||
|
||||
print("📊 上下文相关性分析:")
|
||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
selected_count += 1
|
||||
else:
|
||||
# 如果这个上下文得分很高但放不下,尝试截取
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# 找到包含关键词的部分
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
if len(truncated) > 100: # 确保有足够内容
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total_chars += len(truncated)
|
||||
selected_count += 1
|
||||
break # 不再尝试添加更多上下文
|
||||
|
||||
result = "\n\n".join(selected)
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
||||
return result
|
||||
|
||||
|
||||
def get_search_params_by_category(category: str):
|
||||
"""根据问题类别调整检索参数"""
|
||||
params_map = {
|
||||
"Multi-Hop": {"limit": 20, "max_chars": 15000},
|
||||
"Temporal": {"limit": 16, "max_chars": 10000},
|
||||
"Open Domain": {"limit": 24, "max_chars": 18000},
|
||||
"Single-Hop": {"limit": 12, "max_chars": 8000},
|
||||
}
|
||||
return params_map.get(category, {"limit": 16, "max_chars": 12000})
|
||||
|
||||
|
||||
async def run_locomo_eval(
|
||||
sample_size: int = 1,
|
||||
group_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000, # 保持默认值不变
|
||||
llm_temperature: float = 0.0,
|
||||
llm_max_tokens: int = 32,
|
||||
search_type: str = "hybrid", # 保持默认值不变
|
||||
output_path: str | None = None,
|
||||
skip_ingest_if_exists: bool = True,
|
||||
llm_timeout: float = 10.0,
|
||||
llm_max_retries: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# 函数内部使用三路检索逻辑,但保持参数签名不变
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
|
||||
qa_items: List[Dict[str, Any]] = []
|
||||
if isinstance(raw, list):
|
||||
for entry in raw:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
items: List[Dict[str, Any]] = qa_items[:sample_size]
|
||||
|
||||
# === 保持原来的数据摄入逻辑 ===
|
||||
entries = raw if isinstance(raw, list) else [raw]
|
||||
|
||||
# 只摄入前1条对话(保持原样)
|
||||
max_dialogues_to_ingest = 1
|
||||
contents: List[str] = []
|
||||
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条")
|
||||
|
||||
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
conv = entry.get("conversation", {})
|
||||
sample_id = entry.get("sample_id", f"unknown_{i}")
|
||||
|
||||
print(f"🔍 处理对话 {i+1}: {sample_id}")
|
||||
|
||||
lines: List[str] = []
|
||||
if isinstance(conv, dict):
|
||||
# 收集所有 session_* 的消息
|
||||
session_count = 0
|
||||
for key, val in conv.items():
|
||||
if isinstance(val, list) and key.startswith("session_"):
|
||||
session_count += 1
|
||||
for msg in val:
|
||||
role = msg.get("speaker") or "用户"
|
||||
text = msg.get("text") or ""
|
||||
text = str(text).strip()
|
||||
if not text:
|
||||
continue
|
||||
lines.append(f"{role}: {text}")
|
||||
|
||||
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
|
||||
|
||||
if not lines:
|
||||
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
|
||||
continue
|
||||
|
||||
contents.append("\n".join(lines))
|
||||
|
||||
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
|
||||
|
||||
# 选择要评测的QA对(从所有对话中选取)
|
||||
indexed_items: List[tuple[int, Dict[str, Any]]] = []
|
||||
if isinstance(raw, list):
|
||||
for e_idx, entry in enumerate(raw):
|
||||
for qa in entry.get("qa", []):
|
||||
indexed_items.append((e_idx, qa))
|
||||
else:
|
||||
for qa in raw.get("qa", []):
|
||||
indexed_items.append((0, qa))
|
||||
|
||||
# 这里使用sample_size来限制评测的QA数量
|
||||
selected = indexed_items[:sample_size]
|
||||
items: List[Dict[str, Any]] = [qa for _, qa in selected]
|
||||
|
||||
print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话")
|
||||
# === 修改结束 ===
|
||||
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 关键修复:强制重新摄入纯净的对话数据
|
||||
print("🔄 强制重新摄入纯净的对话数据...")
|
||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
# 初始化embedder用于直接调用
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# connector initialized above
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
# 上下文诊断收集
|
||||
per_query_context_counts: List[int] = []
|
||||
per_query_context_avg_tokens: List[float] = []
|
||||
per_query_context_chars: List[int] = []
|
||||
per_query_context_tokens_total: List[int] = []
|
||||
# 详细样本调试信息
|
||||
samples: List[Dict[str, Any]] = []
|
||||
# 通用指标
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
# 参考 LoCoMo 评测的类别专用 F1(multi-hop 使用多答案 F1)
|
||||
loc_f1s: List[float] = []
|
||||
# Per-category aggregation
|
||||
cat_counts: Dict[str, int] = {}
|
||||
cat_f1s: Dict[str, List[float]] = {}
|
||||
cat_b1s: Dict[str, List[float]] = {}
|
||||
cat_jss: Dict[str, List[float]] = {}
|
||||
cat_loc_f1s: Dict[str, List[float]] = {}
|
||||
try:
|
||||
for item in items:
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
# 确保答案是字符串
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
cat = get_category_label(item)
|
||||
|
||||
print(f"\n=== 处理问题: {q} ===")
|
||||
|
||||
# 根据类别调整检索参数
|
||||
search_params = get_search_params_by_category(cat)
|
||||
adjusted_limit = search_params["limit"]
|
||||
max_chars = search_params["max_chars"]
|
||||
|
||||
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
|
||||
|
||||
# 改进的检索逻辑:使用三路检索(statements, dialogues, entities)
|
||||
t0 = time.time()
|
||||
contexts_all: List[str] = []
|
||||
search_results = None # 保存完整的检索结果
|
||||
|
||||
try:
|
||||
if search_type == "embedding":
|
||||
# 直接调用嵌入检索,包含三路数据
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
||||
|
||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
elif search_type == "keyword":
|
||||
# 直接调用关键词检索
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit
|
||||
)
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
|
||||
|
||||
# 构建上下文
|
||||
for d in dialogs:
|
||||
content = str(d.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
# 实体处理(关键词检索的实体可能没有分数)
|
||||
if entities:
|
||||
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
|
||||
if entity_names:
|
||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
||||
|
||||
else: # hybrid
|
||||
# 🎯 关键修复:混合检索使用更严格的回退机制
|
||||
print("🔀 使用混合检索(带回退机制)...")
|
||||
try:
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
)
|
||||
|
||||
# 🎯 关键修复:正确处理混合检索的扁平结构
|
||||
# 新的API返回扁平结构,直接从顶层获取结果
|
||||
if search_results and isinstance(search_results, dict):
|
||||
# 新API返回扁平结构:直接从顶层获取
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# 检查是否有有效结果
|
||||
if chunks or statements or entities or summaries:
|
||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
|
||||
else:
|
||||
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
|
||||
reranked = search_results.get("reranked_results", {})
|
||||
if reranked and isinstance(reranked, dict):
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述")
|
||||
else:
|
||||
raise ValueError("混合检索返回空结果")
|
||||
else:
|
||||
raise ValueError("混合检索返回空结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
|
||||
|
||||
# 🎯 统一处理:构建上下文(所有检索类型共用)
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
# 关键修复:过滤掉包含当前问题答案的上下文
|
||||
filtered_contexts = []
|
||||
for context in contexts_all:
|
||||
content = str(context)
|
||||
# 排除包含当前问题标准答案的上下文
|
||||
if ref_str and ref_str.strip() and ref_str.strip() in content:
|
||||
print("🚫 过滤掉包含标准答案的上下文")
|
||||
continue
|
||||
filtered_contexts.append(context)
|
||||
|
||||
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
|
||||
contexts_all = filtered_contexts
|
||||
|
||||
# 输出完整的检索结果信息
|
||||
print("🔍 检索结果详情:")
|
||||
if search_results:
|
||||
output_data = {
|
||||
"statements": [
|
||||
{
|
||||
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
|
||||
"score": s.get("score", 0.0)
|
||||
}
|
||||
for s in (statements[:2] if 'statements' in locals() else [])
|
||||
],
|
||||
"dialogues": [
|
||||
{
|
||||
"uuid": d.get("uuid", ""),
|
||||
"group_id": d.get("group_id", ""),
|
||||
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
|
||||
"score": d.get("score", 0.0)
|
||||
}
|
||||
for d in (dialogs[:2] if 'dialogs' in locals() else [])
|
||||
],
|
||||
"entities": [
|
||||
{
|
||||
"name": e.get("name", ""),
|
||||
"entity_type": e.get("entity_type", ""),
|
||||
"score": e.get("score", 0.0)
|
||||
}
|
||||
for e in (entities[:2] if 'entities' in locals() else [])
|
||||
]
|
||||
}
|
||||
print(json.dumps(output_data, ensure_ascii=False, indent=2))
|
||||
else:
|
||||
print(" 无检索结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ {search_type}检索失败: {e}")
|
||||
contexts_all = []
|
||||
search_results = None
|
||||
|
||||
t1 = time.time()
|
||||
latencies_search.append((t1 - t0) * 1000)
|
||||
|
||||
# 使用智能上下文选择
|
||||
context_text = ""
|
||||
if contexts_all:
|
||||
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
|
||||
|
||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
||||
if len(context_text) > max_chars:
|
||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
||||
|
||||
# 时间解析
|
||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
||||
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
||||
|
||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
||||
|
||||
# 显示不同上下文的预览
|
||||
print("🔍 上下文预览:")
|
||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
||||
preview = context[:150].replace('\n', ' ')
|
||||
print(f" 上下文{j+1}: {preview}...")
|
||||
|
||||
else:
|
||||
print("❌ 没有检索到有效上下文")
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# 记录上下文诊断信息
|
||||
per_query_context_counts.append(len(contexts_all))
|
||||
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
|
||||
per_query_context_chars.append(len(context_text))
|
||||
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
|
||||
|
||||
# LLM 提示词
|
||||
messages = [
|
||||
{"role": "system", "content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)},
|
||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
# 使用异步调用
|
||||
resp = await llm_client.chat(messages=messages)
|
||||
t3 = time.time()
|
||||
latencies_llm.append((t3 - t2) * 1000)
|
||||
|
||||
# 兼容不同的响应格式
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
||||
|
||||
# 计算指标(确保使用字符串)
|
||||
f1_val = common_f1(str(pred), ref_str)
|
||||
b1_val = bleu1(str(pred), ref_str)
|
||||
j_val = jaccard(str(pred), ref_str)
|
||||
|
||||
f1s.append(f1_val)
|
||||
b1s.append(b1_val)
|
||||
jss.append(j_val)
|
||||
|
||||
# Accumulate by category
|
||||
cat_counts[cat] = cat_counts.get(cat, 0) + 1
|
||||
cat_f1s.setdefault(cat, []).append(f1_val)
|
||||
cat_b1s.setdefault(cat, []).append(b1_val)
|
||||
cat_jss.setdefault(cat, []).append(j_val)
|
||||
|
||||
# LoCoMo 专用 F1:multi-hop(1) 使用多答案 F1,其它(2/3/4)使用单答案 F1
|
||||
if item.get("category") in [2, 3, 4]:
|
||||
loc_val = loc_f1_score(str(pred), ref_str)
|
||||
elif item.get("category") in [1]:
|
||||
loc_val = loc_multi_f1(str(pred), ref_str)
|
||||
else:
|
||||
loc_val = loc_f1_score(str(pred), ref_str)
|
||||
loc_f1s.append(loc_val)
|
||||
cat_loc_f1s.setdefault(cat, []).append(loc_val)
|
||||
|
||||
# 保存完整的检索结果信息
|
||||
samples.append({
|
||||
"question": q,
|
||||
"answer": ref_str,
|
||||
"category": cat,
|
||||
"prediction": pred,
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"b1": b1_val,
|
||||
"j": j_val,
|
||||
"loc_f1": loc_val
|
||||
},
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": adjusted_limit,
|
||||
"max_chars": max_chars
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": (t1 - t0) * 1000,
|
||||
"llm_ms": (t3 - t2) * 1000
|
||||
}
|
||||
})
|
||||
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"✅ 正确答案: {ref_str}")
|
||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
|
||||
|
||||
# Compute per-category averages and dispersion (std, iqr)
|
||||
def _percentile(sorted_vals: List[float], p: float) -> float:
|
||||
if not sorted_vals:
|
||||
return 0.0
|
||||
if len(sorted_vals) == 1:
|
||||
return sorted_vals[0]
|
||||
k = (len(sorted_vals) - 1) * p
|
||||
f = int(k)
|
||||
c = f + 1 if f + 1 < len(sorted_vals) else f
|
||||
if f == c:
|
||||
return sorted_vals[f]
|
||||
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
|
||||
|
||||
by_category: Dict[str, Dict[str, float | int]] = {}
|
||||
for c in cat_counts:
|
||||
f_list = cat_f1s.get(c, [])
|
||||
b_list = cat_b1s.get(c, [])
|
||||
j_list = cat_jss.get(c, [])
|
||||
lf_list = cat_loc_f1s.get(c, [])
|
||||
j_sorted = sorted(j_list)
|
||||
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
|
||||
j_q75 = _percentile(j_sorted, 0.75)
|
||||
j_q25 = _percentile(j_sorted, 0.25)
|
||||
by_category[c] = {
|
||||
"count": cat_counts[c],
|
||||
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
|
||||
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
|
||||
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
|
||||
"j_std": j_std,
|
||||
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
|
||||
# 参考 LoCoMo 评测的类别专用 F1
|
||||
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
|
||||
}
|
||||
|
||||
# 累加命中(cum accuracy by category):与 evaluation_stats.py 输出形式相仿
|
||||
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
|
||||
|
||||
result = {
|
||||
"dataset": "locomo",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"f1": sum(f1s) / max(len(f1s), 1),
|
||||
"b1": sum(b1s) / max(len(b1s), 1),
|
||||
"j": sum(jss) / max(len(jss), 1),
|
||||
# LoCoMo 类别专用 F1 的总体
|
||||
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
|
||||
},
|
||||
"by_category": by_category,
|
||||
"category_counts": cat_counts,
|
||||
"cum_accuracy_by_category": cum_accuracy_by_category,
|
||||
"context": {
|
||||
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
|
||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
||||
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
|
||||
"skip_ingest_if_exists": skip_ingest_if_exists,
|
||||
"llm_timeout": llm_timeout,
|
||||
"llm_max_retries": llm_max_retries,
|
||||
"llm_temperature": llm_temperature,
|
||||
"llm_max_tokens": llm_max_tokens
|
||||
},
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
if output_path:
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ 结果已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"❌ 保存结果失败: {e}")
|
||||
return result
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
|
||||
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
|
||||
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
|
||||
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
|
||||
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
|
||||
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
|
||||
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
|
||||
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
|
||||
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
|
||||
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
|
||||
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
|
||||
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
|
||||
args = parser.parse_args()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
result = asyncio.run(run_locomo_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
output_path=args.output_path,
|
||||
skip_ingest_if_exists=args.skip_ingest_if_exists,
|
||||
llm_timeout=args.llm_timeout,
|
||||
llm_max_retries=args.llm_max_retries
|
||||
))
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("📊 最终评测结果:")
|
||||
print(f" 样本数量: {result['items']}")
|
||||
print(f" F1: {result['metrics']['f1']:.3f}")
|
||||
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
|
||||
print(f" Jaccard: {result['metrics']['j']:.3f}")
|
||||
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
|
||||
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
|
||||
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
|
||||
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
|
||||
|
||||
if result['by_category']:
|
||||
print("\n📈 按类别细分:")
|
||||
for cat, metrics in result['by_category'].items():
|
||||
print(f" {cat}:")
|
||||
print(f" 样本数: {metrics['count']}")
|
||||
print(f" F1: {metrics['f1']:.3f}")
|
||||
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
|
||||
print(f" Jaccard: {metrics['j']:.3f} (±{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,324 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
|
||||
if not contexts:
|
||||
return ""
|
||||
import re
|
||||
# 提取问题关键词(移除停用词)
|
||||
question_lower = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but'
|
||||
}
|
||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
||||
|
||||
# 评分
|
||||
scored = []
|
||||
for i, ctx in enumerate(contexts):
|
||||
ctx_lower = (ctx or "").lower()
|
||||
score = 0
|
||||
matches = 0
|
||||
for w in question_words:
|
||||
if w in ctx_lower:
|
||||
matches += 1
|
||||
score += ctx_lower.count(w) * 2
|
||||
length = len(ctx)
|
||||
if 100 < length < 2000:
|
||||
score += 5
|
||||
elif length >= 2000:
|
||||
score += 2
|
||||
if i < 3:
|
||||
score += 3
|
||||
scored.append((score, ctx, matches))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择直到达到字符限制,必要时截断包含关键词的段落
|
||||
selected: List[str] = []
|
||||
total = 0
|
||||
for score, ctx, _ in scored:
|
||||
if total + len(ctx) <= max_chars:
|
||||
selected.append(ctx)
|
||||
total += len(ctx)
|
||||
else:
|
||||
if score > 10 and total < max_chars - 200:
|
||||
remaining = max_chars - total
|
||||
lines = ctx.split('\n')
|
||||
rel_lines: List[str] = []
|
||||
cur = 0
|
||||
for line in lines:
|
||||
l = line.lower()
|
||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
||||
rel_lines.append(line)
|
||||
cur += len(line)
|
||||
if rel_lines:
|
||||
truncated = '\n'.join(rel_lines)
|
||||
if len(truncated) > 50:
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total += len(truncated)
|
||||
break
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
|
||||
"""Compose a text context from `dialog` list in msc_self_instruct item."""
|
||||
parts: List[str] = []
|
||||
for turn in dialog_obj.get("dialog", []):
|
||||
speaker = turn.get("speaker", "")
|
||||
text = turn.get("text", "")
|
||||
if text:
|
||||
parts.append(f"{speaker}: {text}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Combine dialogues from embedding and keyword searches (embedding first)."""
|
||||
if results is None:
|
||||
return []
|
||||
emb = []
|
||||
kw = []
|
||||
if isinstance(results.get("embedding_search"), dict):
|
||||
emb = results.get("embedding_search", {}).get("dialogues", []) or []
|
||||
elif isinstance(results.get("dialogues"), list):
|
||||
emb = results.get("dialogues", []) or []
|
||||
if isinstance(results.get("keyword_search"), dict):
|
||||
kw = results.get("keyword_search", {}).get("dialogues", []) or []
|
||||
seen = set()
|
||||
merged: List[Dict[str, Any]] = []
|
||||
for d in emb:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
for d in kw:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
|
||||
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
||||
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
||||
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Evaluate each item
|
||||
connector = Neo4jConnector()
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
contexts_used: List[str] = []
|
||||
correct_flags: List[float] = []
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
try:
|
||||
for item in items:
|
||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
||||
# 检索:对齐 locomo 的三路检索(dialogues/statements/entities)
|
||||
t0 = time.time()
|
||||
try:
|
||||
results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
t1 = time.time()
|
||||
latencies_search.append((t1 - t0) * 1000)
|
||||
|
||||
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
|
||||
contexts_all: List[str] = []
|
||||
if results:
|
||||
if search_type == "hybrid":
|
||||
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
|
||||
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
|
||||
emb_dialogs = emb.get("dialogues", [])
|
||||
emb_statements = emb.get("statements", [])
|
||||
emb_entities = emb.get("entities", [])
|
||||
kw_dialogs = kw.get("dialogues", [])
|
||||
kw_statements = kw.get("statements", [])
|
||||
kw_entities = kw.get("entities", [])
|
||||
all_dialogs = emb_dialogs + kw_dialogs
|
||||
all_statements = emb_statements + kw_statements
|
||||
all_entities = emb_entities + kw_entities
|
||||
|
||||
# 简单去重与限制
|
||||
seen_texts = set()
|
||||
for d in all_dialogs:
|
||||
text = str(d.get("content", "")).strip()
|
||||
if text and text not in seen_texts:
|
||||
contexts_all.append(text)
|
||||
seen_texts.add(text)
|
||||
if len(contexts_all) >= search_limit:
|
||||
break
|
||||
for s in all_statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text and text not in seen_texts:
|
||||
contexts_all.append(text)
|
||||
seen_texts.add(text)
|
||||
if len(contexts_all) >= search_limit:
|
||||
break
|
||||
# 实体摘要(最多3个)
|
||||
names = []
|
||||
merged_entities = all_entities[:]
|
||||
for e in merged_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
if name and name not in names:
|
||||
names.append(name)
|
||||
if len(names) >= 3:
|
||||
break
|
||||
if names:
|
||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
||||
else:
|
||||
dialogs = results.get("dialogues", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
for d in dialogs:
|
||||
text = str(d.get("content", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
for s in statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
|
||||
if names:
|
||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
||||
|
||||
# 智能选择并截断到预算
|
||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
||||
if not context_text:
|
||||
context_text = "No relevant context found."
|
||||
contexts_used.append(context_text[:200])
|
||||
|
||||
# Call LLM (使用异步调用)
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
|
||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
t2 = time.time()
|
||||
resp = await llm_client.chat(messages=messages)
|
||||
t3 = time.time()
|
||||
latencies_llm.append((t3 - t2) * 1000)
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
||||
correct_flags.append(exact_match(pred, reference))
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
)
|
||||
f1s.append(f1_score(str(pred), str(reference)))
|
||||
b1s.append(bleu1(str(pred), str(reference)))
|
||||
jss.append(jaccard(str(pred), str(reference)))
|
||||
|
||||
# Aggregate metrics
|
||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
||||
result = {
|
||||
"dataset": "memsciqa",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"accuracy": acc,
|
||||
# Placeholders for extensibility
|
||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
||||
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
||||
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"avg_context_tokens": ctx_avg_tokens,
|
||||
}
|
||||
return result
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
|
||||
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(
|
||||
run_memsciqa_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
)
|
||||
)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,576 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
if _p not in sys.path:
|
||||
sys.path.insert(0, _p)
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
||||
except Exception:
|
||||
# 兜底:简单实现(必要时)
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
tp = len(set(ps) & set(rs))
|
||||
if tp == 0:
|
||||
return 0.0
|
||||
precision = tp / len(ps)
|
||||
recall = tp / len(rs)
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
overlap = len([w for w in ps if w in rs])
|
||||
return overlap / max(len(ps), 1)
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
ps = set(pred.lower().split())
|
||||
rs = set(ref.lower().split())
|
||||
union = len(ps | rs)
|
||||
if union == 0:
|
||||
return 0.0
|
||||
return len(ps & rs) / union
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
|
||||
|
||||
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
question_lower = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but'
|
||||
}
|
||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
||||
|
||||
scored = []
|
||||
for i, ctx in enumerate(contexts):
|
||||
ctx_lower = (ctx or "").lower()
|
||||
score = 0
|
||||
matches = 0
|
||||
for w in question_words:
|
||||
if w in ctx_lower:
|
||||
matches += 1
|
||||
score += ctx_lower.count(w) * 2
|
||||
length = len(ctx)
|
||||
if 100 < length < 2000:
|
||||
score += 5
|
||||
elif length >= 2000:
|
||||
score += 2
|
||||
if i < 3:
|
||||
score += 3
|
||||
scored.append((score, ctx, matches))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
selected: List[str] = []
|
||||
total = 0
|
||||
for score, ctx, _ in scored:
|
||||
if total + len(ctx) <= max_chars:
|
||||
selected.append(ctx)
|
||||
total += len(ctx)
|
||||
else:
|
||||
if score > 10 and total < max_chars - 200:
|
||||
remaining = max_chars - total
|
||||
lines = ctx.split('\n')
|
||||
rel_lines: List[str] = []
|
||||
cur = 0
|
||||
for line in lines:
|
||||
l = line.lower()
|
||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
||||
rel_lines.append(line)
|
||||
cur += len(line)
|
||||
if rel_lines:
|
||||
truncated = '\n'.join(rel_lines)
|
||||
if len(truncated) > 50:
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total += len(truncated)
|
||||
break
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
|
||||
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3)。"""
|
||||
ql = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
|
||||
}
|
||||
words = re.findall(r"\b[\w-]+\b", ql)
|
||||
kws = [w for w in words if w not in stop_words and len(w) >= 3]
|
||||
# 去重保序
|
||||
seen = set()
|
||||
uniq = []
|
||||
for w in kws:
|
||||
if w not in seen:
|
||||
uniq.append(w)
|
||||
seen.add(w)
|
||||
if len(uniq) >= max_keywords:
|
||||
break
|
||||
return uniq
|
||||
|
||||
|
||||
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
|
||||
"""对上下文进行简单相关性打分,仅用于控制台可视化。
|
||||
|
||||
评分: score = match_count*200 + min(len(text), 100000)/100
|
||||
"""
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
tl = (ctx or "").lower()
|
||||
match_count = sum(1 for k in keywords if k in tl)
|
||||
length = len(ctx)
|
||||
score = match_count * 200 + min(length, 100000) / 100.0
|
||||
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
|
||||
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
|
||||
return results[:max(top_n, 0)]
|
||||
|
||||
|
||||
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
|
||||
|
||||
|
||||
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"未找到数据集: {data_path}")
|
||||
items: List[Dict[str, Any]] = []
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
items.append(json.loads(line))
|
||||
except Exception:
|
||||
# 跳过坏行但不中断
|
||||
continue
|
||||
return items
|
||||
|
||||
|
||||
async def run_memsciqa_test(
|
||||
sample_size: int = 3,
|
||||
group_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000,
|
||||
llm_temperature: float = 0.0,
|
||||
llm_max_tokens: int = 64,
|
||||
search_type: str = "embedding",
|
||||
data_path: str | None = None,
|
||||
start_index: int = 0,
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
|
||||
|
||||
- 支持从指定索引开始与评估全部样本(sample_size<=0)
|
||||
- 支持在摄入前重置组(清空图)与跳过摄入
|
||||
- 支持 keyword / embedding / hybrid 三种检索
|
||||
"""
|
||||
|
||||
# 默认使用指定的 memsci 组 ID
|
||||
group_id = group_id or "group_memsci"
|
||||
|
||||
# 数据路径解析(项目根与当前工作目录兜底)
|
||||
if not data_path:
|
||||
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
if os.path.exists(proj_path):
|
||||
data_path = proj_path
|
||||
elif os.path.exists(cwd_path):
|
||||
data_path = cwd_path
|
||||
else:
|
||||
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。")
|
||||
|
||||
# 加载数据
|
||||
all_items = load_dataset_memsciqa(data_path)
|
||||
if sample_size is None or sample_size <= 0:
|
||||
items = all_items[start_index:]
|
||||
else:
|
||||
items = all_items[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化 LLM(纯测试:不进行摄入)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
||||
connector = Neo4jConnector()
|
||||
embedder = None
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 评估循环
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
# 存储完整上下文文本用于统计
|
||||
contexts_used: List[str] = []
|
||||
per_query_context_chars: List[int] = []
|
||||
per_query_context_counts: List[int] = []
|
||||
correct_flags: List[float] = []
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
samples: List[Dict[str, Any]] = []
|
||||
|
||||
total_items = len(items)
|
||||
for idx, item in enumerate(items):
|
||||
if verbose:
|
||||
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
|
||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
||||
|
||||
# 三路检索:chunks/statements/entities/summaries(对齐 qwen_search_eval.py)
|
||||
t0 = time.time()
|
||||
results = None
|
||||
try:
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
# 使用嵌入检索(与 qwen_search_eval 对齐)
|
||||
results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
elif search_type == "keyword":
|
||||
# 关键词检索(直接调用 graph_search)
|
||||
results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
t1 = time.time()
|
||||
search_ms = (t1 - t0) * 1000
|
||||
latencies_search.append(search_ms)
|
||||
|
||||
# 构建上下文:包含 chunks、陈述、摘要和实体(对齐 qwen_search_eval.py)
|
||||
contexts_all: List[str] = []
|
||||
retrieved_counts: Dict[str, int] = {}
|
||||
if results:
|
||||
chunks = results.get("chunks", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
summaries = results.get("summaries", [])
|
||||
retrieved_counts = {
|
||||
"chunks": len(chunks),
|
||||
"statements": len(statements),
|
||||
"entities": len(entities),
|
||||
"summaries": len(summaries),
|
||||
}
|
||||
# 优先使用 chunks
|
||||
for c in chunks:
|
||||
text = str(c.get("content", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 statements
|
||||
for s in statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 summaries
|
||||
for sm in summaries:
|
||||
text = str(sm.get("summary", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 实体摘要:最多加入前3个高分实体(对齐 qwen_search_eval.py)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
if verbose:
|
||||
if retrieved_counts:
|
||||
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
q_keywords = extract_question_keywords(question, max_keywords=8)
|
||||
if q_keywords:
|
||||
print(f"🔍 问题关键词: {set(q_keywords)}")
|
||||
if contexts_all:
|
||||
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
|
||||
if analysis:
|
||||
print("📊 上下文相关性分析:")
|
||||
for a in analysis:
|
||||
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
|
||||
# 打印检索到的上下文预览,便于定位为何为 Unknown
|
||||
print("🔎 上下文预览(最多前10条,每条截断展示):")
|
||||
for i, ctx in enumerate(contexts_all[:10]):
|
||||
preview = str(ctx).replace("\n", " ")
|
||||
if len(preview) > 300:
|
||||
preview = preview[:300] + "..."
|
||||
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
|
||||
# 标注参考答案是否出现在任一上下文中
|
||||
ref_lower = (str(reference) or "").lower()
|
||||
if ref_lower:
|
||||
hits = []
|
||||
for i, ctx in enumerate(contexts_all):
|
||||
if ref_lower in str(ctx).lower():
|
||||
hits.append(i+1)
|
||||
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
|
||||
|
||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
||||
if not context_text:
|
||||
context_text = "No relevant context found."
|
||||
contexts_used.append(context_text)
|
||||
per_query_context_chars.append(len(context_text))
|
||||
per_query_context_counts.append(len(contexts_all))
|
||||
|
||||
if verbose:
|
||||
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
|
||||
# 展示拼接后的上下文片段,便于核查是否包含答案
|
||||
concat_preview = context_text.replace("\n", " ")
|
||||
if len(concat_preview) > 600:
|
||||
concat_preview = concat_preview[:600] + "..."
|
||||
print(f"🧵 拼接上下文预览: {concat_preview}")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
|
||||
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
|
||||
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
|
||||
"3) Keep your answer brief and to the point;\n"
|
||||
"4) Do not add explanations or additional text beyond the answer."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
try:
|
||||
# 使用异步调用
|
||||
resp = await llm.chat(messages=messages)
|
||||
# 更健壮的响应解析,处理不同的LLM响应格式
|
||||
if hasattr(resp, 'content'):
|
||||
pred = resp.content.strip()
|
||||
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
|
||||
pred = resp["choices"][0]["message"]["content"].strip()
|
||||
elif isinstance(resp, dict) and "content" in resp:
|
||||
pred = resp["content"].strip()
|
||||
elif isinstance(resp, str):
|
||||
pred = resp.strip()
|
||||
else:
|
||||
pred = "Unknown"
|
||||
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
|
||||
|
||||
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
|
||||
if pred.lower() in ["unknown", ""]:
|
||||
# 如果参考答案在上下文中存在,但LLM返回Unknown,可能是提示词问题
|
||||
ref_lower = (str(reference) or "").lower()
|
||||
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
|
||||
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词")
|
||||
except Exception as e:
|
||||
# 更详细的错误处理
|
||||
pred = "Unknown"
|
||||
print(f"⚠️ LLM调用异常: {e}")
|
||||
t3 = time.time()
|
||||
llm_ms = (t3 - t2) * 1000
|
||||
latencies_llm.append(llm_ms)
|
||||
|
||||
exact = exact_match(pred, reference)
|
||||
correct_flags.append(exact)
|
||||
f1_val = f1_score(str(pred), str(reference))
|
||||
b1_val = bleu1(str(pred), str(reference))
|
||||
j_val = jaccard(str(pred), str(reference))
|
||||
f1s.append(f1_val)
|
||||
b1s.append(b1_val)
|
||||
jss.append(j_val)
|
||||
|
||||
if verbose:
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"✅ 正确答案: {reference}")
|
||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
|
||||
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
|
||||
|
||||
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
|
||||
samples.append({
|
||||
"question": str(question),
|
||||
"answer": str(reference),
|
||||
"prediction": str(pred),
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"b1": b1_val,
|
||||
"j": j_val
|
||||
},
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": search_limit,
|
||||
"max_chars": context_char_budget
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_ms,
|
||||
"llm_ms": llm_ms
|
||||
}
|
||||
})
|
||||
|
||||
# 计算总体指标与聚合
|
||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
||||
result = {
|
||||
"dataset": "memsciqa",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
||||
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
||||
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
||||
},
|
||||
"context": {
|
||||
"avg_tokens": ctx_avg_tokens,
|
||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
||||
"avg_memory_tokens": 0.0
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"llm_temperature": llm_temperature,
|
||||
"llm_max_tokens": llm_max_tokens,
|
||||
"search_type": search_type,
|
||||
"start_index": start_index,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
|
||||
},
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
try:
|
||||
await connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
|
||||
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
|
||||
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
|
||||
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
|
||||
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
|
||||
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)")
|
||||
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl)")
|
||||
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)")
|
||||
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
|
||||
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
|
||||
args = parser.parse_args()
|
||||
|
||||
sample_size = 0 if args.all else args.sample_size
|
||||
|
||||
verbose_flag = False if args.quiet else args.verbose
|
||||
result = asyncio.run(
|
||||
run_memsciqa_test(
|
||||
sample_size=sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
data_path=args.data_path,
|
||||
start_index=args.start_index,
|
||||
verbose=verbose_flag,
|
||||
)
|
||||
)
|
||||
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
# 结果保存
|
||||
out_path = args.output
|
||||
if not out_path:
|
||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dataset_results_dir = os.path.join(eval_dir, "results")
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
|
||||
try:
|
||||
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n💾 结果已保存: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 结果保存失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,150 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
# Add src directory to Python path for proper imports when running from evaluation directory
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
||||
|
||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
|
||||
|
||||
|
||||
async def run(
|
||||
dataset: str,
|
||||
sample_size: int,
|
||||
reset_group: bool,
|
||||
group_id: str | None,
|
||||
judge_model: str | None = None,
|
||||
search_limit: int | None = None,
|
||||
context_char_budget: int | None = None,
|
||||
llm_temperature: float | None = None,
|
||||
llm_max_tokens: int | None = None,
|
||||
search_type: str | None = None,
|
||||
start_index: int | None = None,
|
||||
max_contexts_per_item: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
|
||||
if reset_group:
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
await connector.delete_group(group_id)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
if dataset == "locomo":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_locomo_eval(**kwargs)
|
||||
|
||||
if dataset == "memsciqa":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_memsciqa_eval(**kwargs)
|
||||
|
||||
if dataset == "longmemeval":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
if start_index is not None:
|
||||
kwargs["start_index"] = start_index
|
||||
if max_contexts_per_item is not None:
|
||||
kwargs["max_contexts_per_item"] = max_contexts_per_item
|
||||
return await run_longmemeval_test(**kwargs)
|
||||
raise ValueError(f"未知数据集: {dataset}")
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
|
||||
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
|
||||
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
|
||||
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
|
||||
# 仅透传到 longmemeval;其他数据集忽略
|
||||
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)")
|
||||
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)")
|
||||
parser.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation/<dataset>/results 目录")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(run(
|
||||
args.dataset,
|
||||
args.sample_size,
|
||||
args.reset_group,
|
||||
args.group_id,
|
||||
args.judge_model,
|
||||
args.search_limit,
|
||||
args.context_char_budget,
|
||||
args.llm_temperature,
|
||||
args.llm_max_tokens,
|
||||
args.search_type,
|
||||
args.start_index,
|
||||
args.max_contexts_per_item,
|
||||
))
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
# 结果输出逻辑保持不变
|
||||
if args.output:
|
||||
out_path = args.output
|
||||
else:
|
||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
|
||||
out_filename = f"{args.dataset}_{args.sample_size}.json"
|
||||
out_path = os.path.join(dataset_results_dir, out_filename)
|
||||
|
||||
out_dir = os.path.dirname(out_path)
|
||||
if out_dir and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n结果已保存到: {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -187,11 +187,11 @@ class ChunkerClient:
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||
|
||||
|
||||
Each message creates one chunk, directly inheriting role information.
|
||||
If a message is too long, it will be split into multiple sub-chunks,
|
||||
each maintaining the same speaker.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If dialogue has no messages or chunking fails
|
||||
"""
|
||||
@@ -201,9 +201,9 @@ class ChunkerClient:
|
||||
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||
f"Cannot generate chunks from empty dialogue."
|
||||
)
|
||||
|
||||
|
||||
dialogue.chunks = []
|
||||
|
||||
|
||||
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||
# Validate message has required attributes
|
||||
@@ -212,13 +212,13 @@ class ChunkerClient:
|
||||
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||
f"missing 'role' or 'msg' attribute"
|
||||
)
|
||||
|
||||
|
||||
msg_content = msg.msg.strip()
|
||||
|
||||
|
||||
# Skip empty messages
|
||||
if not msg_content:
|
||||
continue
|
||||
|
||||
|
||||
# 如果消息太长,可以进一步分块
|
||||
if len(msg_content) > self.chunk_size:
|
||||
# 对单个消息的内容进行分块
|
||||
@@ -228,14 +228,14 @@ class ChunkerClient:
|
||||
raise ValueError(
|
||||
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
for idx, sub_chunk in enumerate(sub_chunks):
|
||||
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||
sub_chunk_text = sub_chunk_text.strip()
|
||||
|
||||
|
||||
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||
continue
|
||||
|
||||
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {sub_chunk_text}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
@@ -260,7 +260,7 @@ class ChunkerClient:
|
||||
},
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
# Validate we generated at least one chunk
|
||||
if not dialogue.chunks:
|
||||
raise ValueError(
|
||||
@@ -268,7 +268,7 @@ class ChunkerClient:
|
||||
f"All messages were either empty or too short. "
|
||||
f"Messages count: {len(dialogue.context.msgs)}"
|
||||
)
|
||||
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
|
||||
@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology models
|
||||
from app.core.memory.models.ontology_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
# Variable configuration models
|
||||
from app.core.memory.models.variate_config import (
|
||||
StatementExtractionConfig,
|
||||
@@ -105,6 +111,9 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
# Variable configuration
|
||||
"StatementExtractionConfig",
|
||||
"ForgettingEngineConfig",
|
||||
|
||||
@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
|
||||
"""Parameters for temporal search queries in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
group_id: Group ID to filter search results (default: 'test')
|
||||
end_user_id: Group ID to filter search results (default: 'test')
|
||||
apply_id: Application ID to filter search results
|
||||
user_id: User ID to filter search results
|
||||
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
||||
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
|
||||
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
||||
limit: Maximum number of results to return (default: 3)
|
||||
"""
|
||||
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
||||
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
||||
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
||||
|
||||
@@ -103,9 +103,7 @@ class Edge(BaseModel):
|
||||
id: Unique identifier for the edge
|
||||
source: ID of the source node
|
||||
target: ID of the target node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this edge
|
||||
created_at: Timestamp when the edge was created (system perspective)
|
||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||
@@ -113,9 +111,7 @@ class Edge(BaseModel):
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||
source: str = Field(..., description="The ID of the source node.")
|
||||
target: str = Field(..., description="The ID of the target node.")
|
||||
group_id: str = Field(..., description="The group ID of the edge.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
@@ -185,18 +181,14 @@ class Node(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the node
|
||||
name: Name of the node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this node
|
||||
created_at: Timestamp when the node was created (system perspective)
|
||||
expired_at: Optional timestamp when the node expires (system perspective)
|
||||
"""
|
||||
id: str = Field(..., description="The unique identifier for the node.")
|
||||
name: str = Field(..., description="The name of the node.")
|
||||
group_id: str = Field(..., description="The group ID of the node.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||
@@ -421,7 +413,8 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class Statement(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the statement
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
group_id: Optional group ID for multi-tenancy
|
||||
end_user_id: Optional group ID for multi-tenancy
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
@@ -73,7 +73,7 @@ class Statement(BaseModel):
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
statement: str = Field(..., description="The text content of the statement.")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
|
||||
context: Full conversation context
|
||||
dialog_embedding: Optional embedding vector for the entire dialog
|
||||
ref_id: Reference ID linking to external dialog system
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
created_at: Timestamp when the dialog was created
|
||||
expired_at: Timestamp when the dialog expires (default: far future)
|
||||
metadata: Additional metadata as key-value pairs
|
||||
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
|
||||
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
||||
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
||||
group_id: str = Field(default=..., description="Group ID of dialogue data")
|
||||
user_id: str = Field(..., description="USER ID of dialogue data")
|
||||
apply_id: str = Field(..., description="APPLY ID of dialogue data")
|
||||
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
|
||||
return []
|
||||
|
||||
def assign_group_id_to_statements(self) -> None:
|
||||
"""Assign this dialog's group_id to all statements in all chunks.
|
||||
"""Assign this dialog's end_user_id to all statements in all chunks.
|
||||
|
||||
This method updates statements that don't have a group_id set.
|
||||
This method updates statements that don't have a end_user_id set.
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
for statement in chunk.statements:
|
||||
if statement.group_id is None:
|
||||
statement.group_id = self.group_id
|
||||
if statement.end_user_id is None:
|
||||
statement.end_user_id = self.end_user_id
|
||||
|
||||
135
api/app/core/memory/models/ontology_models.py
Normal file
135
api/app/core/memory/models/ontology_models.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Models for ontology classes and extraction responses.
|
||||
|
||||
This module contains Pydantic models for representing extracted ontology classes
|
||||
from scenario descriptions, following OWL ontology engineering standards.
|
||||
|
||||
Classes:
|
||||
OntologyClass: Represents an extracted ontology class
|
||||
OntologyExtractionResponse: Response model containing extracted ontology classes
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class OntologyClass(BaseModel):
|
||||
"""Represents an extracted ontology class from scenario description.
|
||||
|
||||
An ontology class represents an abstract category or concept in a domain,
|
||||
following OWL ontology engineering standards and naming conventions.
|
||||
|
||||
Attributes:
|
||||
id: Unique string identifier for the ontology class
|
||||
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
|
||||
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
|
||||
description: Textual description of the class
|
||||
examples: List of concrete instance examples of this class
|
||||
parent_class: Optional name of the parent class in the hierarchy
|
||||
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
|
||||
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: uuid4().hex,
|
||||
description="Unique identifier for the ontology class"
|
||||
)
|
||||
name: str = Field(
|
||||
...,
|
||||
description="Name of the class in PascalCase format"
|
||||
)
|
||||
name_chinese: Optional[str] = Field(
|
||||
None,
|
||||
description="Chinese translation of the class name"
|
||||
)
|
||||
description: str = Field(
|
||||
...,
|
||||
description="Description of the class"
|
||||
)
|
||||
examples: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of concrete instance examples"
|
||||
)
|
||||
parent_class: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the parent class in the hierarchy"
|
||||
)
|
||||
entity_type: str = Field(
|
||||
...,
|
||||
description="Type/category of the entity"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain this class belongs to"
|
||||
)
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_pascal_case(cls, v: str) -> str:
|
||||
"""Validate that the class name follows PascalCase convention.
|
||||
|
||||
PascalCase rules:
|
||||
- Must start with an uppercase letter
|
||||
- Cannot contain spaces
|
||||
- Should not contain special characters except underscores
|
||||
|
||||
Args:
|
||||
v: The class name to validate
|
||||
|
||||
Returns:
|
||||
The validated class name
|
||||
|
||||
Raises:
|
||||
ValueError: If the name doesn't follow PascalCase convention
|
||||
"""
|
||||
if not v:
|
||||
raise ValueError("Class name cannot be empty")
|
||||
|
||||
if not v[0].isupper():
|
||||
raise ValueError(
|
||||
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
|
||||
)
|
||||
|
||||
if ' ' in v:
|
||||
raise ValueError(
|
||||
f"Class name '{v}' cannot contain spaces (PascalCase)"
|
||||
)
|
||||
|
||||
# Check for invalid characters (allow alphanumeric and underscore only)
|
||||
if not all(c.isalnum() or c == '_' for c in v):
|
||||
raise ValueError(
|
||||
f"Class name '{v}' contains invalid characters. "
|
||||
"Only alphanumeric characters and underscores are allowed"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class OntologyExtractionResponse(BaseModel):
|
||||
"""Response model for ontology extraction from LLM.
|
||||
|
||||
This model represents the structured output from the LLM when
|
||||
extracting ontology classes from scenario descriptions.
|
||||
|
||||
Attributes:
|
||||
classes: List of extracted ontology classes
|
||||
domain: Domain/field the scenario belongs to
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
classes: List[OntologyClass] = Field(
|
||||
default_factory=list,
|
||||
description="List of extracted ontology classes"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain/field the scenario belongs to"
|
||||
)
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -396,13 +397,13 @@ def rerank_with_activation(
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
search_type: Type of search (keyword, embedding, hybrid)
|
||||
group_id: Group identifier for filtering
|
||||
end_user_id: Group identifier for filtering
|
||||
limit: Maximum number of results
|
||||
include: List of result types to include
|
||||
log_file: Deprecated parameter, kept for backward compatibility
|
||||
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
f"group_id={group_id}, limit={limit}, include={include}"
|
||||
f"end_user_id={end_user_id}, limit={limit}, include={include}"
|
||||
)
|
||||
|
||||
|
||||
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
group_id: str | None,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
@@ -715,7 +716,7 @@ async def run_hybrid_search(
|
||||
}
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, group_id, limit, include)
|
||||
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
results = {}
|
||||
@@ -732,7 +733,7 @@ async def run_hybrid_search(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
)
|
||||
@@ -769,7 +770,7 @@ async def run_hybrid_search(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
)
|
||||
@@ -916,9 +917,7 @@ async def run_hybrid_search(
|
||||
|
||||
|
||||
async def search_by_temporal(
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -929,7 +928,7 @@ async def search_by_temporal(
|
||||
Temporal search across Statements.
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id
|
||||
- Optionally filters by end_user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
@@ -939,9 +938,7 @@ async def search_by_temporal(
|
||||
end_date = normalize_date_safe(end_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
@@ -950,9 +947,7 @@ async def search_by_temporal(
|
||||
})
|
||||
statements = await search_graph_by_temporal(
|
||||
connector=connector,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
end_user_id=params.end_user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
@@ -964,9 +959,7 @@ async def search_by_temporal(
|
||||
|
||||
async def search_by_keyword_temporal(
|
||||
query_text: str,
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
|
||||
invalid_date = normalize_date_safe(invalid_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
|
||||
statements = await search_graph_by_keyword_temporal(
|
||||
connector=connector,
|
||||
query_text=query_text,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
end_user_id=params.end_user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
|
||||
|
||||
async def search_chunk_by_chunk_id(
|
||||
chunk_id: str,
|
||||
group_id: Optional[str] = "test",
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
|
||||
chunks = await search_graph_by_chunk_id(
|
||||
connector=connector,
|
||||
chunk_id=chunk_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
@@ -555,8 +555,8 @@ class DataPreprocessor:
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
|
||||
|
||||
# 获取group_id,如果不存在则生成默认值
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
# 获取end_user_id,如果不存在则生成默认值
|
||||
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
@@ -574,7 +574,7 @@ class DataPreprocessor:
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
@@ -644,7 +644,7 @@ class DataPreprocessor:
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
@@ -657,7 +657,7 @@ class DataPreprocessor:
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
|
||||
@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
if len(desc_b) > len(desc_a):
|
||||
canonical.description = desc_b
|
||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||
fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
def _extract_sources(txt: str) -> List[str]:
|
||||
sources: List[str] = []
|
||||
if not txt:
|
||||
return sources
|
||||
for line in str(txt).splitlines():
|
||||
ln = line.strip()
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
# fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
# def _extract_sources(txt: str) -> List[str]:
|
||||
# sources: List[str] = []
|
||||
# if not txt:
|
||||
# return sources
|
||||
# for line in str(txt).splitlines():
|
||||
# ln = line.strip()
|
||||
# 支持“来源:”或“来源:”前缀
|
||||
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
if m:
|
||||
content = m.group(1).strip()
|
||||
if content:
|
||||
sources.append(content)
|
||||
# m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
# if m:
|
||||
# content = m.group(1).strip()
|
||||
# if content:
|
||||
# sources.append(content)
|
||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||
if not sources and txt.strip():
|
||||
sources.append(txt.strip())
|
||||
return sources
|
||||
# if not sources and txt.strip():
|
||||
# sources.append(txt.strip())
|
||||
# return sources
|
||||
try:
|
||||
src_a = _extract_sources(fact_a)
|
||||
src_b = _extract_sources(fact_b)
|
||||
seen = set()
|
||||
merged_sources: List[str] = []
|
||||
for s in src_a + src_b:
|
||||
if s and s not in seen:
|
||||
seen.add(s)
|
||||
merged_sources.append(s)
|
||||
if merged_sources:
|
||||
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
elif fact_b and not fact_a:
|
||||
canonical.fact_summary = fact_b
|
||||
# src_a = _extract_sources(fact_a)
|
||||
# src_b = _extract_sources(fact_b)
|
||||
# seen = set()
|
||||
# merged_sources: List[str] = []
|
||||
# for s in src_a + src_b:
|
||||
# if s and s not in seen:
|
||||
# seen.add(s)
|
||||
# merged_sources.append(s)
|
||||
# if merged_sources:
|
||||
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
# elif fact_b and not fact_a:
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
# 兜底:若解析失败,保留较长文本
|
||||
if len(fact_b) > len(fact_a):
|
||||
canonical.fact_summary = fact_b
|
||||
# if len(fact_b) > len(fact_a):
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -199,7 +202,7 @@ def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
"""
|
||||
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||
"""
|
||||
exact_merge_map: Dict[str, Dict] = {}
|
||||
@@ -210,8 +213,8 @@ def accurate_match(
|
||||
for ent in entity_nodes:
|
||||
name_norm = (getattr(ent, "name", "") or "").strip()
|
||||
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
||||
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
|
||||
# 为避免跨业务组误并,明确以 group_id 为范围边界
|
||||
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
|
||||
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
|
||||
if key not in canonical_map:
|
||||
canonical_map[key] = ent
|
||||
id_redirect[ent.id] = ent.id
|
||||
@@ -223,11 +226,11 @@ def accurate_match(
|
||||
id_redirect[ent.id] = canonical.id
|
||||
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
||||
try:
|
||||
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
if k not in exact_merge_map:
|
||||
exact_merge_map[k] = {
|
||||
"canonical_id": canonical.id,
|
||||
"group_id": canonical.group_id,
|
||||
"end_user_id": canonical.end_user_id,
|
||||
"name": canonical.name,
|
||||
"entity_type": canonical.entity_type,
|
||||
"merged_ids": set(),
|
||||
@@ -596,7 +599,7 @@ def fuzzy_match(
|
||||
b = deduped_entities[j]
|
||||
|
||||
# 跳过不同业务组的实体
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
j += 1
|
||||
continue
|
||||
|
||||
@@ -671,7 +674,7 @@ def fuzzy_match(
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
fuzzy_merge_records.append(
|
||||
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
|
||||
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
|
||||
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
||||
)
|
||||
except Exception:
|
||||
@@ -779,7 +782,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
# 记录 LLM 融合日志
|
||||
try:
|
||||
llm_records.append(
|
||||
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
||||
except Exception:
|
||||
@@ -847,7 +850,7 @@ async def LLM_disamb_decision(
|
||||
id_redirect[k] = a.id
|
||||
try:
|
||||
disamb_records.append(
|
||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
|
||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||
desc_a = (getattr(a, "description", "") or "")
|
||||
desc_b = (getattr(b, "description", "") or "")
|
||||
fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
score_a = len(desc_a) + len(fact_a)
|
||||
score_b = len(desc_b) + len(fact_b)
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
# fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
# score_a = len(desc_a) + len(fact_a)
|
||||
# score_b = len(desc_b) + len(fact_b)
|
||||
score_a = len(desc_a)
|
||||
score_b = len(desc_b)
|
||||
if score_a != score_b:
|
||||
return 0 if score_a >= score_b else 1
|
||||
return 0
|
||||
@@ -174,7 +177,7 @@ async def _judge_pair(
|
||||
pass
|
||||
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
@@ -189,7 +192,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -197,7 +201,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||
@@ -235,7 +240,7 @@ async def _judge_pair_disamb(
|
||||
except Exception:
|
||||
pass
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
"name_embed_sim": name_embed_sim,
|
||||
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
prompt = render_entity_dedup_prompt(
|
||||
@@ -317,8 +324,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
||||
a = entity_nodes[i]
|
||||
for j in range(i + 1, len(entity_nodes)):
|
||||
b = entity_nodes[j]
|
||||
# 规则1:必须属于同一组(group_id相同,不同组的实体不重复)
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
# 规则1:必须属于同一组(end_user_id相同,不同组的实体不重复)
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
continue
|
||||
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
||||
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
||||
@@ -474,7 +481,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
- max_rounds: upper bound for iterative passes (default 3)
|
||||
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
||||
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
||||
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
|
||||
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
|
||||
|
||||
Returns:
|
||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||
@@ -509,7 +516,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
|
||||
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
||||
"""
|
||||
按 group_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||
按 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||
|
||||
Args:
|
||||
nodes: 实体节点列表
|
||||
@@ -519,7 +526,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
"""
|
||||
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
||||
for e in nodes:
|
||||
gid = getattr(e, "group_id", None)
|
||||
gid = getattr(e, "end_user_id", None)
|
||||
groups.setdefault(str(gid), []).append(e)
|
||||
blocks: List[List[ExtractedEntityNode]] = []
|
||||
for gid, arr in groups.items():
|
||||
@@ -559,7 +566,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
||||
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
||||
current_nodes = _collapse_nodes(current_nodes)
|
||||
# 步骤2:分块(按group_id分块,避免跨组处理)
|
||||
# 步骤2:分块(按end_user_id分块,避免跨组处理)
|
||||
blocks = _partition_blocks(current_nodes)
|
||||
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
||||
break
|
||||
@@ -645,7 +652,7 @@ async def llm_disambiguate_pairs_iterative(
|
||||
a = entity_nodes[i]
|
||||
b = entity_nodes[j]
|
||||
# 必须同组
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
continue
|
||||
ta = getattr(a, "entity_type", None)
|
||||
tb = getattr(b, "entity_type", None)
|
||||
|
||||
@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
return ExtractedEntityNode(
|
||||
id=row.get("id"),
|
||||
name=row.get("name") or "",
|
||||
group_id=row.get("group_id") or "",
|
||||
end_user_id=row.get("end_user_id") or "",
|
||||
user_id=row.get("user_id") or "",
|
||||
apply_id=row.get("apply_id") or "",
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
@@ -72,14 +72,15 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
description=row.get("description") or "",
|
||||
aliases=row.get("aliases") or [],
|
||||
name_embedding=row.get("name_embedding") or [],
|
||||
fact_summary=row.get("fact_summary") or "",
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=row.get("fact_summary") or "",
|
||||
connect_strength=row.get("connect_strength") or "",
|
||||
)
|
||||
|
||||
|
||||
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
||||
connector: Neo4jConnector,
|
||||
group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||
end_user_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||
@@ -88,7 +89,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
第二层去重消歧:
|
||||
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
|
||||
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
|
||||
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
||||
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
||||
"""
|
||||
@@ -102,7 +103,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
|
||||
]
|
||||
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
||||
connector=connector, group_id=group_id,
|
||||
connector=connector, end_user_id=end_user_id,
|
||||
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
||||
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
||||
)
|
||||
|
||||
@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
|
||||
if pipeline_config is None:
|
||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||
|
||||
# 先探测 group_id,决定报告写入策略
|
||||
group_id: Optional[str] = None
|
||||
# 先探测 end_user_id,决定报告写入策略
|
||||
end_user_id: Optional[str] = None
|
||||
for dd in dialog_data_list:
|
||||
group_id = getattr(dd, "group_id", None)
|
||||
if group_id:
|
||||
end_user_id = getattr(dd, "end_user_id", None)
|
||||
if end_user_id:
|
||||
break
|
||||
|
||||
# 第一层去重消歧
|
||||
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
|
||||
|
||||
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
if connector:
|
||||
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
||||
connector=connector,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
entity_nodes=dedup_entity_nodes,
|
||||
statement_entity_edges=dedup_statement_entity_edges,
|
||||
entity_entity_edges=dedup_entity_entity_edges,
|
||||
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
else:
|
||||
print("Skip second-layer dedup: missing connector")
|
||||
else:
|
||||
print("Skip second-layer dedup: missing group_id")
|
||||
print("Skip second-layer dedup: missing end_user_id")
|
||||
except Exception as e:
|
||||
print(f"Second-layer dedup failed: {e}")
|
||||
|
||||
|
||||
@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
|
||||
for c_idx, chunk in enumerate(dialog.chunks):
|
||||
all_chunks.append((chunk, dialog.group_id, dialogue_content))
|
||||
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
|
||||
chunk_metadata.append((d_idx, c_idx))
|
||||
|
||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
|
||||
# 全局并行处理所有分块
|
||||
async def extract_for_chunk(chunk_data, chunk_index):
|
||||
nonlocal completed_chunks
|
||||
chunk, group_id, dialogue_content = chunk_data
|
||||
chunk, end_user_id, dialogue_content = chunk_data
|
||||
try:
|
||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
|
||||
|
||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
|
||||
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
|
||||
config_id = dialog_data_list[0].config_id
|
||||
|
||||
# 加载DataConfig
|
||||
data_config = None
|
||||
# 加载MemoryConfig
|
||||
memory_config = None
|
||||
if config_id:
|
||||
try:
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
data_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if data_config and not data_config.emotion_enabled:
|
||||
if memory_config and not memory_config.emotion_enabled:
|
||||
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
|
||||
logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
else:
|
||||
logger.info("未找到config_id,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
# 如果配置未启用情绪提取,直接返回空映射
|
||||
if not data_config or not data_config.emotion_enabled:
|
||||
if not memory_config or not memory_config.emotion_enabled:
|
||||
logger.info("情绪提取未启用,跳过")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
|
||||
total_statements += 1
|
||||
# 只处理用户的陈述句 (role 为 "user")
|
||||
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
||||
all_statements.append((statement, data_config))
|
||||
all_statements.append((statement, memory_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
filtered_statements += 1
|
||||
|
||||
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
|
||||
# 初始化情绪提取服务
|
||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||
emotion_service = EmotionExtractionService(
|
||||
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None
|
||||
llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
|
||||
)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
|
||||
id=dialog_data.id,
|
||||
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
|
||||
ref_id=dialog_data.ref_id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=dialog_data.context.content if dialog_data.context else "",
|
||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
||||
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
|
||||
id=chunk.id,
|
||||
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
|
||||
dialog_id=dialog_data.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=chunk.content,
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
|
||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
statement=statement.statement,
|
||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edge = StatementChunkEdge(
|
||||
source=statement.id,
|
||||
target=chunk.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
)
|
||||
@@ -1072,13 +1064,16 @@ class ExtractionOrchestrator:
|
||||
if statement.triplet_extraction_info:
|
||||
triplet_info = statement.triplet_extraction_info
|
||||
|
||||
# 创建实体索引到ID的映射
|
||||
# 创建实体索引到ID的映射(支持多种索引方式)
|
||||
entity_idx_to_id = {}
|
||||
|
||||
# 创建实体节点
|
||||
for entity_idx, entity in enumerate(triplet_info.entities):
|
||||
# 映射实体索引到实体ID
|
||||
# 映射实体索引到实体ID(使用多个键以提高容错性)
|
||||
# 1. 使用实体自己的 entity_idx
|
||||
entity_idx_to_id[entity.entity_idx] = entity.id
|
||||
# 2. 使用枚举索引(从0开始)
|
||||
entity_idx_to_id[entity_idx] = entity.id
|
||||
|
||||
if entity.id not in entity_id_set:
|
||||
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
|
||||
@@ -1090,14 +1085,13 @@ class ExtractionOrchestrator:
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
@@ -1112,9 +1106,7 @@ class ExtractionOrchestrator:
|
||||
source=statement.id,
|
||||
target=entity.id,
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
)
|
||||
@@ -1134,9 +1126,7 @@ class ExtractionOrchestrator:
|
||||
relation_type=triplet.predicate,
|
||||
statement=statement.statement,
|
||||
source_statement_id=statement.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
@@ -1163,9 +1153,18 @@ class ExtractionOrchestrator:
|
||||
relationship_result
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
|
||||
f"object_id={triplet.object_id}, statement_id={statement.id}"
|
||||
# 改进的警告信息,包含更多调试信息
|
||||
missing_subject = "subject" if not subject_entity_id else ""
|
||||
missing_object = "object" if not object_entity_id else ""
|
||||
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
|
||||
|
||||
logger.debug(
|
||||
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
|
||||
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
|
||||
f"object_id={triplet.object_id} ({triplet.object_name}), "
|
||||
f"predicate={triplet.predicate}, "
|
||||
f"statement_id={statement.id}, "
|
||||
f"available_indices={sorted(entity_idx_to_id.keys())}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -1763,14 +1762,14 @@ class ExtractionOrchestrator:
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
end_user_id: str = "group_1",
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> List[DialogData]:
|
||||
"""从测试数据生成分块对话
|
||||
|
||||
Args:
|
||||
chunker_strategy: 分块策略(默认: RecursiveChunker)
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
indices: 要处理的数据索引列表(可选)
|
||||
|
||||
Returns:
|
||||
@@ -1834,7 +1833,7 @@ async def get_chunked_dialogs(
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=data['id'],
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
metadata=dialog_metadata,
|
||||
)
|
||||
|
||||
@@ -1936,7 +1935,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
|
||||
async def get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "default",
|
||||
end_user_id: str = "default",
|
||||
user_id: str = "default",
|
||||
apply_id: str = "default",
|
||||
indices: Optional[List[int]] = None,
|
||||
@@ -1948,7 +1947,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
|
||||
Args:
|
||||
chunker_strategy: 分块策略
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
user_id: 用户ID
|
||||
apply_id: 应用ID
|
||||
indices: 要处理的数据索引列表
|
||||
@@ -1976,11 +1975,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
# 设置 group_id, user_id, apply_id
|
||||
# 设置 end_user_id
|
||||
for dd in preprocessed_data:
|
||||
dd.group_id = group_id
|
||||
dd.user_id = user_id
|
||||
dd.apply_id = apply_id
|
||||
dd.end_user_id = end_user_id
|
||||
|
||||
# 步骤2: 语义剪枝
|
||||
try:
|
||||
|
||||
@@ -8,4 +8,5 @@
|
||||
- TemporalExtractor: 时间信息提取
|
||||
- EmbeddingGenerator: 嵌入向量生成
|
||||
- MemorySummaryGenerator: 记忆摘要生成
|
||||
- OntologyExtractor: 本体类提取
|
||||
"""
|
||||
|
||||
@@ -14,6 +14,34 @@ from pydantic import Field
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
# 支持的语言列表和默认回退值
|
||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
||||
FALLBACK_LANGUAGE = "en"
|
||||
|
||||
|
||||
def validate_language(language: Optional[str]) -> str:
|
||||
"""
|
||||
校验语言参数,确保其为有效值。
|
||||
|
||||
Args:
|
||||
language: 待校验的语言代码
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
"""
|
||||
if language is None:
|
||||
return FALLBACK_LANGUAGE
|
||||
|
||||
lang = str(language).lower().strip()
|
||||
if lang in SUPPORTED_LANGUAGES:
|
||||
return lang
|
||||
|
||||
logger.warning(
|
||||
f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'。"
|
||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
||||
)
|
||||
return FALLBACK_LANGUAGE
|
||||
|
||||
|
||||
class MemorySummaryResponse(RobustLLMResponse):
|
||||
"""Structured response for summary generation per chunk.
|
||||
@@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse):
|
||||
|
||||
async def generate_title_and_type_for_summary(
|
||||
content: str,
|
||||
llm_client
|
||||
llm_client,
|
||||
language: str = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
为MemorySummary生成标题和类型
|
||||
@@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary(
|
||||
Args:
|
||||
content: Summary的内容文本
|
||||
llm_client: LLM客户端实例
|
||||
language: 生成标题使用的语言 ("zh" 中文, "en" 英文),如果为None则从配置读取
|
||||
|
||||
Returns:
|
||||
(标题, 类型)元组
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||
from app.core.config import settings
|
||||
|
||||
# 如果没有指定语言,从配置中读取,并校验有效性
|
||||
if language is None:
|
||||
language = settings.DEFAULT_LANGUAGE
|
||||
language = validate_language(language)
|
||||
|
||||
# 定义有效的类型集合
|
||||
VALID_TYPES = {
|
||||
@@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary(
|
||||
}
|
||||
DEFAULT_TYPE = "conversation" # 默认类型
|
||||
|
||||
# 根据语言设置默认标题
|
||||
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
|
||||
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
|
||||
ERROR_TITLE = "错误" if language == "zh" else "Error"
|
||||
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
|
||||
|
||||
try:
|
||||
if not content:
|
||||
logger.warning("content为空,无法生成标题和类型")
|
||||
return ("空内容", DEFAULT_TYPE)
|
||||
logger.warning(f"content为空,无法生成标题和类型 (language={language})")
|
||||
return (DEFAULT_TITLE, DEFAULT_TYPE)
|
||||
|
||||
# 1. 渲染Jinja2提示词模板
|
||||
prompt = await render_episodic_title_and_type_prompt(content)
|
||||
# 1. 渲染Jinja2提示词模板,传递语言参数
|
||||
prompt = await render_episodic_title_and_type_prompt(content, language=language)
|
||||
|
||||
# 2. 调用LLM生成标题和类型
|
||||
messages = [
|
||||
@@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary(
|
||||
json_str = json_str.strip()
|
||||
|
||||
result_data = json.loads(json_str)
|
||||
title = result_data.get("title", "未知标题")
|
||||
title = result_data.get("title", UNKNOWN_TITLE)
|
||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||
|
||||
# 5. 校验和归一化类型
|
||||
@@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary(
|
||||
f"已归一化为 '{episodic_type}'"
|
||||
)
|
||||
|
||||
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
||||
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
|
||||
return (title, episodic_type)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
||||
return ("解析失败", DEFAULT_TYPE)
|
||||
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
|
||||
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
||||
return ("错误", DEFAULT_TYPE)
|
||||
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
|
||||
return (ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
async def _process_chunk_summary(
|
||||
dialog: DialogData,
|
||||
@@ -153,11 +195,16 @@ async def _process_chunk_summary(
|
||||
return None
|
||||
|
||||
try:
|
||||
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
|
||||
from app.core.config import settings
|
||||
language = validate_language(settings.DEFAULT_LANGUAGE)
|
||||
|
||||
# Render prompt via Jinja2 for a single chunk
|
||||
prompt_content = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk.content,
|
||||
json_schema=MemorySummaryResponse.model_json_schema(),
|
||||
max_words=200,
|
||||
language=language,
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -178,9 +225,10 @@ async def _process_chunk_summary(
|
||||
try:
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
llm_client=llm_client
|
||||
llm_client=llm_client,
|
||||
language=language
|
||||
)
|
||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||
# Continue without title and type
|
||||
@@ -193,9 +241,9 @@ async def _process_chunk_summary(
|
||||
node = MemorySummaryNode(
|
||||
id=uuid4().hex,
|
||||
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
||||
group_id=dialog.group_id,
|
||||
user_id=dialog.user_id,
|
||||
apply_id=dialog.apply_id,
|
||||
end_user_id=dialog.end_user_id,
|
||||
user_id=dialog.end_user_id,
|
||||
apply_id=dialog.end_user_id,
|
||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||
created_at=datetime.now(),
|
||||
expired_at=datetime(9999, 12, 31),
|
||||
|
||||
@@ -0,0 +1,482 @@
|
||||
"""Ontology class extraction from scenario descriptions using LLM.
|
||||
|
||||
This module provides the OntologyExtractor class for extracting ontology classes
|
||||
from natural language scenario descriptions. It uses LLM-driven extraction combined
|
||||
with two-layer validation (string validation + OWL semantic validation).
|
||||
|
||||
Classes:
|
||||
OntologyExtractor: Extracts ontology classes from scenario descriptions
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.ontology_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
|
||||
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyExtractor:
|
||||
"""Extractor for ontology classes from scenario descriptions.
|
||||
|
||||
This extractor uses LLM to identify abstract classes and concepts from
|
||||
natural language scenario descriptions, following OWL ontology engineering
|
||||
standards. It performs two-layer validation:
|
||||
1. String validation (naming conventions, reserved words, duplicates)
|
||||
2. OWL semantic validation (consistency checking, circular inheritance)
|
||||
|
||||
Attributes:
|
||||
llm_client: OpenAI client for LLM calls
|
||||
validator: String validator for class names and descriptions
|
||||
owl_validator: OWL validator for semantic validation
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient):
|
||||
"""Initialize the OntologyExtractor.
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for LLM processing
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.validator = OntologyValidator()
|
||||
self.owl_validator = OWLValidator()
|
||||
|
||||
logger.info("OntologyExtractor initialized")
|
||||
|
||||
async def extract_ontology_classes(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str] = None,
|
||||
max_classes: int = 15,
|
||||
min_classes: int = 5,
|
||||
enable_owl_validation: bool = True,
|
||||
llm_temperature: float = 0.3,
|
||||
llm_max_tokens: int = 2000,
|
||||
max_description_length: int = 500,
|
||||
timeout: Optional[float] = None,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Extract ontology classes from a scenario description.
|
||||
|
||||
This is the main extraction method that orchestrates the entire process:
|
||||
1. Call LLM to extract ontology classes
|
||||
2. Perform first-layer validation (string validation and cleaning)
|
||||
3. Perform second-layer validation (OWL semantic validation)
|
||||
4. Filter invalid classes based on validation errors
|
||||
5. Return validated ontology classes
|
||||
|
||||
Args:
|
||||
scenario: Natural language scenario description
|
||||
domain: Optional domain hint (e.g., "Healthcare", "Education")
|
||||
max_classes: Maximum number of classes to extract (default: 15)
|
||||
min_classes: Minimum number of classes to extract (default: 5)
|
||||
enable_owl_validation: Whether to enable OWL validation (default: True)
|
||||
llm_temperature: LLM temperature parameter (default: 0.3)
|
||||
llm_max_tokens: LLM max tokens parameter (default: 2000)
|
||||
max_description_length: Maximum description length (default: 500)
|
||||
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse containing validated ontology classes
|
||||
|
||||
Raises:
|
||||
ValueError: If scenario is empty or invalid
|
||||
asyncio.TimeoutError: If extraction times out
|
||||
|
||||
Examples:
|
||||
>>> extractor = OntologyExtractor(llm_client)
|
||||
>>> response = await extractor.extract_ontology_classes(
|
||||
... scenario="A hospital manages patient records...",
|
||||
... domain="Healthcare",
|
||||
... max_classes=10,
|
||||
... timeout=30.0
|
||||
... )
|
||||
>>> len(response.classes)
|
||||
7
|
||||
"""
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
|
||||
# Validate input
|
||||
if not scenario or not scenario.strip():
|
||||
logger.error("Scenario description is empty")
|
||||
raise ValueError("Scenario description cannot be empty")
|
||||
|
||||
scenario = scenario.strip()
|
||||
|
||||
logger.info(
|
||||
f"Starting ontology extraction - scenario_length={len(scenario)}, "
|
||||
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
|
||||
f"timeout={timeout}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Call LLM for extraction with timeout
|
||||
logger.info("Step 1: Calling LLM for ontology extraction")
|
||||
llm_start_time = time.time()
|
||||
|
||||
if timeout is not None:
|
||||
# Wrap LLM call with timeout
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.error(
|
||||
f"LLM extraction timed out after {timeout} seconds "
|
||||
f"(actual duration: {llm_duration:.2f}s)"
|
||||
)
|
||||
# Return empty response on timeout
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
else:
|
||||
# No timeout specified, call directly
|
||||
response = await self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
)
|
||||
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.info(
|
||||
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
# Step 2: First-layer validation (string validation and cleaning)
|
||||
logger.info("Step 2: Performing first-layer validation (string validation)")
|
||||
validation_start_time = time.time()
|
||||
|
||||
response = self._validate_and_clean(
|
||||
response=response,
|
||||
max_description_length=max_description_length,
|
||||
)
|
||||
|
||||
validation_duration = time.time() - validation_start_time
|
||||
logger.info(
|
||||
f"After first-layer validation: {len(response.classes)} classes remain "
|
||||
f"(validation took {validation_duration:.2f}s)"
|
||||
)
|
||||
|
||||
# Check if we have enough classes after first-layer validation
|
||||
if len(response.classes) < min_classes:
|
||||
logger.warning(
|
||||
f"Only {len(response.classes)} classes remain after validation, "
|
||||
f"which is below minimum of {min_classes}"
|
||||
)
|
||||
|
||||
# Step 3: Second-layer validation (OWL semantic validation)
|
||||
if enable_owl_validation and response.classes:
|
||||
logger.info("Step 3: Performing second-layer validation (OWL validation)")
|
||||
owl_start_time = time.time()
|
||||
|
||||
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
|
||||
classes=response.classes,
|
||||
)
|
||||
|
||||
owl_duration = time.time() - owl_start_time
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
|
||||
)
|
||||
|
||||
# Filter invalid classes based on errors
|
||||
response = self._filter_invalid_classes(
|
||||
response=response,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"After second-layer validation: {len(response.classes)} classes remain"
|
||||
)
|
||||
else:
|
||||
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
|
||||
else:
|
||||
if not enable_owl_validation:
|
||||
logger.info("Step 3: OWL validation disabled, skipping")
|
||||
else:
|
||||
logger.info("Step 3: No classes to validate, skipping OWL validation")
|
||||
|
||||
# Calculate total duration
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
# Log extraction statistics
|
||||
logger.info(
|
||||
f"Ontology extraction completed - "
|
||||
f"final_class_count={len(response.classes)}, "
|
||||
f"domain={response.domain}, "
|
||||
f"total_duration={total_duration:.2f}s, "
|
||||
f"llm_duration={llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Re-raise timeout errors
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction timed out after {timeout} seconds "
|
||||
f"(total duration: {total_duration:.2f}s)",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty response on failure
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
|
||||
async def _call_llm_for_extraction(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str],
|
||||
max_classes: int,
|
||||
llm_temperature: float,
|
||||
llm_max_tokens: int,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Call LLM to extract ontology classes from scenario.
|
||||
|
||||
This method renders the extraction prompt using the Jinja2 template
|
||||
and calls the LLM with structured output to get ontology classes.
|
||||
|
||||
Args:
|
||||
scenario: Scenario description text
|
||||
domain: Optional domain hint
|
||||
max_classes: Maximum number of classes to extract
|
||||
llm_temperature: LLM temperature parameter
|
||||
llm_max_tokens: LLM max tokens parameter
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse from LLM
|
||||
|
||||
Raises:
|
||||
Exception: If LLM call fails
|
||||
"""
|
||||
try:
|
||||
# Render prompt using template
|
||||
prompt_content = await render_ontology_extraction_prompt(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
json_schema=OntologyExtractionResponse.model_json_schema(),
|
||||
)
|
||||
|
||||
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an expert ontology engineer specializing in knowledge "
|
||||
"representation and OWL standards. Extract ontology classes from "
|
||||
"scenario descriptions following the provided instructions. "
|
||||
"Return valid JSON conforming to the schema."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt_content,
|
||||
},
|
||||
]
|
||||
|
||||
# Call LLM with structured output
|
||||
logger.debug(
|
||||
f"Calling LLM with temperature={llm_temperature}, "
|
||||
f"max_tokens={llm_max_tokens}"
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"LLM extraction successful - extracted {len(response.classes)} classes"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM extraction failed: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def _validate_and_clean(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
max_description_length: int,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Perform first-layer validation: string validation and cleaning.
|
||||
|
||||
This method validates and cleans the extracted ontology classes:
|
||||
1. Validate class names (PascalCase, no reserved words)
|
||||
2. Sanitize invalid class names
|
||||
3. Truncate long descriptions
|
||||
4. Remove duplicate classes
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse from LLM
|
||||
max_description_length: Maximum description length
|
||||
|
||||
Returns:
|
||||
Cleaned OntologyExtractionResponse
|
||||
"""
|
||||
if not response.classes:
|
||||
logger.debug("No classes to validate")
|
||||
return response
|
||||
|
||||
logger.debug(f"Validating {len(response.classes)} classes")
|
||||
|
||||
validated_classes = []
|
||||
|
||||
for ontology_class in response.classes:
|
||||
# Validate class name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"Invalid class name '{ontology_class.name}': {error_msg}"
|
||||
)
|
||||
|
||||
# Attempt to sanitize
|
||||
sanitized_name = self.validator.sanitize_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
|
||||
)
|
||||
|
||||
# Update class name
|
||||
ontology_class.name = sanitized_name
|
||||
|
||||
# Re-validate sanitized name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
sanitized_name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.error(
|
||||
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
|
||||
"Skipping this class."
|
||||
)
|
||||
continue
|
||||
|
||||
# Truncate description if too long
|
||||
if ontology_class.description:
|
||||
original_length = len(ontology_class.description)
|
||||
ontology_class.description = self.validator.truncate_description(
|
||||
ontology_class.description,
|
||||
max_length=max_description_length,
|
||||
)
|
||||
|
||||
if len(ontology_class.description) < original_length:
|
||||
logger.debug(
|
||||
f"Truncated description for '{ontology_class.name}': "
|
||||
f"{original_length} -> {len(ontology_class.description)} chars"
|
||||
)
|
||||
|
||||
validated_classes.append(ontology_class)
|
||||
|
||||
# Remove duplicates (case-insensitive)
|
||||
original_count = len(validated_classes)
|
||||
validated_classes = self.validator.remove_duplicates(validated_classes)
|
||||
|
||||
if len(validated_classes) < original_count:
|
||||
logger.info(
|
||||
f"Removed {original_count - len(validated_classes)} duplicate classes"
|
||||
)
|
||||
|
||||
# Return cleaned response
|
||||
return OntologyExtractionResponse(
|
||||
classes=validated_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
def _filter_invalid_classes(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
errors: List[str],
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Filter invalid classes based on OWL validation errors.
|
||||
|
||||
This method analyzes OWL validation errors and removes classes
|
||||
that caused validation failures (e.g., circular inheritance,
|
||||
inconsistencies).
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse to filter
|
||||
errors: List of error messages from OWL validation
|
||||
|
||||
Returns:
|
||||
Filtered OntologyExtractionResponse
|
||||
"""
|
||||
if not errors:
|
||||
return response
|
||||
|
||||
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
|
||||
|
||||
# Extract class names mentioned in errors
|
||||
invalid_class_names = set()
|
||||
|
||||
for error in errors:
|
||||
# Look for class names in error messages
|
||||
for ontology_class in response.classes:
|
||||
if ontology_class.name in error:
|
||||
invalid_class_names.add(ontology_class.name)
|
||||
logger.debug(
|
||||
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
|
||||
)
|
||||
|
||||
# Filter out invalid classes
|
||||
if invalid_class_names:
|
||||
original_count = len(response.classes)
|
||||
|
||||
filtered_classes = [
|
||||
c for c in response.classes
|
||||
if c.name not in invalid_class_names
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
|
||||
f"{invalid_class_names}"
|
||||
)
|
||||
|
||||
return OntologyExtractionResponse(
|
||||
classes=filtered_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -82,12 +82,12 @@ class StatementExtractor:
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
Args:
|
||||
chunk: Chunk object to process
|
||||
group_id: Group ID to assign to all statements in this chunk
|
||||
end_user_id: Group ID to assign to all statements in this chunk
|
||||
dialogue_content: Full dialogue content to provide as context
|
||||
|
||||
Returns:
|
||||
@@ -158,7 +158,7 @@ class StatementExtractor:
|
||||
temporal_info=temporal_type,
|
||||
relevence_info=relevence_info,
|
||||
chunk_id=chunk.id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
speaker=chunk_speaker,
|
||||
)
|
||||
|
||||
@@ -184,10 +184,10 @@ class StatementExtractor:
|
||||
|
||||
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
||||
|
||||
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
|
||||
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
|
||||
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
||||
results = await asyncio.gather(
|
||||
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
|
||||
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
@@ -225,7 +225,7 @@ class StatementExtractor:
|
||||
for i, statement in enumerate(statements, 1):
|
||||
f.write(f"Statement {i}:\n")
|
||||
f.write(f"Id: {statement.id}\n")
|
||||
f.write(f"Group Id: {statement.group_id}\n")
|
||||
f.write(f"Group Id: {statement.end_user_id}\n")
|
||||
f.write(f"Content: {statement.statement}\n")
|
||||
f.write(f"Type: {statement.stmt_type.value}\n")
|
||||
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
||||
@@ -298,7 +298,7 @@ class StatementExtractor:
|
||||
|
||||
dialog_sections.append({
|
||||
"dialog_id": dialog.ref_id,
|
||||
"group_id": dialog.group_id,
|
||||
"end_user_id": dialog.end_user_id,
|
||||
"content": dialog.content if getattr(dialog, "content", None) else "",
|
||||
"strong": strong_relations,
|
||||
"weak": weak_relations,
|
||||
@@ -312,7 +312,7 @@ class StatementExtractor:
|
||||
for idx, section in enumerate(dialog_sections, 1):
|
||||
f.write(f"Dialog {idx}:\n")
|
||||
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
||||
f.write(f"Group ID: {section.get('group_id', '')}\n")
|
||||
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
|
||||
f.write("Content:\n")
|
||||
f.write(f"{section.get('content', '')}\n")
|
||||
f.write("-" * 40 + "\n\n")
|
||||
|
||||
@@ -132,7 +132,7 @@ class TemporalExtractor:
|
||||
prompt_logger.info("")
|
||||
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
||||
prompt_logger.info(
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -25,6 +25,15 @@ class TripletExtractor:
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
|
||||
def _get_language(self) -> str:
|
||||
"""Get the configured language for entity descriptions
|
||||
|
||||
Returns:
|
||||
Language code ("zh" or "en")
|
||||
"""
|
||||
from app.core.config import settings
|
||||
return settings.DEFAULT_LANGUAGE
|
||||
|
||||
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
||||
"""Process a single statement and return extracted triplets and entities"""
|
||||
# Render the prompt using helper function
|
||||
@@ -40,7 +49,8 @@ class TripletExtractor:
|
||||
statement=statement.statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||
predicate_instructions=PREDICATE_DEFINITIONS
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language()
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
@@ -116,7 +126,7 @@ class TripletExtractor:
|
||||
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
||||
try:
|
||||
prompt_logger.info(
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -75,7 +75,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -91,7 +91,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选,用于过滤)
|
||||
end_user_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
|
||||
Returns:
|
||||
@@ -123,7 +123,7 @@ class AccessHistoryManager:
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
@@ -142,7 +142,7 @@ class AccessHistoryManager:
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -172,7 +172,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_ids: List[str],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -184,7 +184,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
|
||||
Returns:
|
||||
@@ -202,7 +202,7 @@ class AccessHistoryManager:
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
current_time=current_time
|
||||
)
|
||||
tasks.append(task)
|
||||
@@ -235,7 +235,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
@@ -249,14 +249,14 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
@@ -305,7 +305,7 @@ class AccessHistoryManager:
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -313,7 +313,7 @@ class AccessHistoryManager:
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
@@ -329,16 +329,16 @@ class AccessHistoryManager:
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
"""
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
RETURN n.id as id
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {"limit": limit}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
if end_user_id:
|
||||
params["end_user_id"] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
@@ -351,7 +351,7 @@ class AccessHistoryManager:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
@@ -387,7 +387,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
@@ -401,7 +401,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
@@ -411,7 +411,7 @@ class AccessHistoryManager:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
@@ -419,7 +419,7 @@ class AccessHistoryManager:
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
return False
|
||||
@@ -457,8 +457,8 @@ class AccessHistoryManager:
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " WHERE n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
SET n += $repair_data
|
||||
RETURN n
|
||||
@@ -468,8 +468,8 @@ class AccessHistoryManager:
|
||||
'node_id': node_id,
|
||||
'repair_data': repair_data
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -491,7 +491,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
@@ -499,7 +499,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
@@ -507,8 +507,8 @@ class AccessHistoryManager:
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " WHERE n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
RETURN n.id as id,
|
||||
n.importance_score as importance_score,
|
||||
@@ -519,8 +519,8 @@ class AccessHistoryManager:
|
||||
"""
|
||||
|
||||
params = {'node_id': node_id}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -585,7 +585,7 @@ class AccessHistoryManager:
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
update_data: Dict[str, Any],
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
@@ -597,7 +597,7 @@ class AccessHistoryManager:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
update_data: 更新数据
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
@@ -606,13 +606,13 @@ class AccessHistoryManager:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, group_id):
|
||||
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
read_query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
@@ -624,8 +624,8 @@ class AccessHistoryManager:
|
||||
"""
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if group_id:
|
||||
read_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
read_params['end_user_id'] = end_user_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
@@ -656,8 +656,8 @@ class AccessHistoryManager:
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if group_id:
|
||||
where_conditions.append("n.group_id = $group_id")
|
||||
if end_user_id:
|
||||
where_conditions.append("n.end_user_id = $end_user_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
@@ -695,8 +695,8 @@ class AccessHistoryManager:
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if group_id:
|
||||
update_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
update_params['end_user_id'] = end_user_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
@@ -720,7 +720,7 @@ class AccessHistoryManager:
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@@ -11,9 +11,10 @@ Functions:
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
|
||||
|
||||
def load_actr_config_from_db(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库加载 ACT-R 配置参数
|
||||
|
||||
从 PostgreSQL 的 data_config 表读取配置参数,
|
||||
从 PostgreSQL 的 memory_config 表读取配置参数,
|
||||
并计算派生参数(如 forgetting_rate)。
|
||||
|
||||
Args:
|
||||
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
|
||||
|
||||
# 从数据库加载配置
|
||||
try:
|
||||
repository = DataConfigRepository()
|
||||
repository = MemoryConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None:
|
||||
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
|
||||
|
||||
def create_actr_calculator_from_config(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> ACTRCalculator:
|
||||
"""
|
||||
从数据库配置创建 ACTRCalculator 实例
|
||||
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
|
||||
ValueError: 如果指定的 config_id 不存在
|
||||
|
||||
Examples:
|
||||
>>> from sqlalchemy.orm import Session
|
||||
>>> db = Session()
|
||||
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
|
||||
>>> # 使用计算器
|
||||
>>> activation = calculator.calculate_memory_activation(...)
|
||||
"""
|
||||
# 加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
|
||||
@@ -16,6 +16,7 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||
@@ -66,10 +67,10 @@ class ForgettingScheduler:
|
||||
|
||||
async def run_forgetting_cycle(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_merge_batch_size: int = 100,
|
||||
min_days_since_access: int = 30,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -77,7 +78,7 @@ class ForgettingScheduler:
|
||||
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
@@ -107,19 +108,19 @@ class ForgettingScheduler:
|
||||
start_time_iso = start_time.isoformat()
|
||||
|
||||
logger.info(
|
||||
f"开始遗忘周期: group_id={group_id}, "
|
||||
f"开始遗忘周期: end_user_id={end_user_id}, "
|
||||
f"max_batch={max_merge_batch_size}, "
|
||||
f"min_days={min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 步骤1:统计遗忘前的节点数量
|
||||
nodes_before = await self._count_knowledge_nodes(group_id)
|
||||
nodes_before = await self._count_knowledge_nodes(end_user_id)
|
||||
logger.info(f"遗忘前节点总数: {nodes_before}")
|
||||
|
||||
# 步骤2:识别可遗忘的节点对
|
||||
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
min_days_since_access=min_days_since_access
|
||||
)
|
||||
|
||||
@@ -213,7 +214,7 @@ class ForgettingScheduler:
|
||||
'statement_text': pair['statement_text'],
|
||||
'statement_activation': pair['statement_activation'],
|
||||
'statement_importance': pair['statement_importance'],
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
entity_node = {
|
||||
@@ -222,7 +223,7 @@ class ForgettingScheduler:
|
||||
'entity_type': pair['entity_type'],
|
||||
'entity_activation': pair['entity_activation'],
|
||||
'entity_importance': pair['entity_importance'],
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
# 融合节点
|
||||
@@ -262,7 +263,7 @@ class ForgettingScheduler:
|
||||
continue
|
||||
|
||||
# 步骤6:统计遗忘后的节点数量
|
||||
nodes_after = await self._count_knowledge_nodes(group_id)
|
||||
nodes_after = await self._count_knowledge_nodes(end_user_id)
|
||||
logger.info(f"遗忘后节点总数: {nodes_after}")
|
||||
|
||||
# 步骤7:生成遗忘报告
|
||||
@@ -315,7 +316,7 @@ class ForgettingScheduler:
|
||||
|
||||
async def _count_knowledge_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
统计知识层节点总数
|
||||
@@ -323,7 +324,7 @@ class ForgettingScheduler:
|
||||
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
|
||||
Returns:
|
||||
int: 知识层节点总数
|
||||
@@ -333,16 +334,16 @@ class ForgettingScheduler:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
params = {}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -90,7 +91,7 @@ class ForgettingStrategy:
|
||||
|
||||
async def find_forgettable_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
min_days_since_access: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -102,7 +103,7 @@ class ForgettingStrategy:
|
||||
3. Statement 和 Entity 之间存在关系边
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
|
||||
Returns:
|
||||
@@ -136,8 +137,8 @@ class ForgettingStrategy:
|
||||
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
RETURN s.id as statement_id,
|
||||
@@ -159,8 +160,8 @@ class ForgettingStrategy:
|
||||
'threshold': self.forgetting_threshold,
|
||||
'cutoff_time': cutoff_time_iso
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -176,7 +177,7 @@ class ForgettingStrategy:
|
||||
self,
|
||||
statement_node: Dict[str, Any],
|
||||
entity_node: Dict[str, Any],
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
@@ -247,8 +248,8 @@ class ForgettingStrategy:
|
||||
entity_activation = entity_node['entity_activation']
|
||||
entity_importance = entity_node['entity_importance']
|
||||
|
||||
# 获取 group_id(从 statement 或 entity 节点)
|
||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
||||
# 获取 end_user_id(从 statement 或 entity 节点)
|
||||
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
|
||||
|
||||
# 生成摘要内容
|
||||
summary_text = await self._generate_summary(
|
||||
@@ -325,7 +326,7 @@ class ForgettingStrategy:
|
||||
last_access_time: $current_time,
|
||||
access_count: 1,
|
||||
version: 1,
|
||||
group_id: $group_id,
|
||||
end_user_id: $end_user_id,
|
||||
created_at: datetime($current_time),
|
||||
merged_at: datetime($current_time)
|
||||
})
|
||||
@@ -423,7 +424,7 @@ class ForgettingStrategy:
|
||||
'inherited_activation': inherited_activation,
|
||||
'inherited_importance': inherited_importance,
|
||||
'current_time': current_time_iso,
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -462,7 +463,7 @@ class ForgettingStrategy:
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
@@ -527,7 +528,7 @@ class ForgettingStrategy:
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
async def _get_llm_client(self, db, config_id: int):
|
||||
async def _get_llm_client(self, db, config_id: UUID):
|
||||
"""
|
||||
从数据库获取 LLM 客户端
|
||||
|
||||
@@ -539,11 +540,11 @@ class ForgettingStrategy:
|
||||
LLM 客户端实例,如果无法获取则返回 None
|
||||
"""
|
||||
try:
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 从数据库读取配置
|
||||
repository = DataConfigRepository()
|
||||
repository = MemoryConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None or db_config.llm_id is None:
|
||||
|
||||
@@ -37,7 +37,7 @@ __all__ = [
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str = "hybrid",
|
||||
group_id: str | None = None,
|
||||
end_user_id: str | None = None,
|
||||
apply_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
@@ -54,7 +54,7 @@ async def run_hybrid_search(
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||
group_id: 组ID过滤
|
||||
end_user_id: 组ID过滤
|
||||
apply_id: 应用ID过滤
|
||||
user_id: 用户ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
@@ -104,7 +104,7 @@ async def run_hybrid_search(
|
||||
# 执行搜索
|
||||
result = await strategy.search(
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
alpha=alpha,
|
||||
|
||||
@@ -77,7 +77,7 @@
|
||||
# async def search(
|
||||
# self,
|
||||
# query_text: str,
|
||||
# group_id: Optional[str] = None,
|
||||
# end_user_id: Optional[str] = None,
|
||||
# limit: int = 50,
|
||||
# include: Optional[List[str]] = None,
|
||||
# **kwargs
|
||||
@@ -86,7 +86,7 @@
|
||||
|
||||
# Args:
|
||||
# query_text: 查询文本
|
||||
# group_id: 可选的组ID过滤
|
||||
# end_user_id: 可选的组ID过滤
|
||||
# limit: 每个类别的最大结果数
|
||||
# include: 要包含的搜索类别列表
|
||||
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||
@@ -94,7 +94,7 @@
|
||||
# Returns:
|
||||
# SearchResult: 搜索结果对象
|
||||
# """
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# # 从kwargs中获取参数
|
||||
# alpha = kwargs.get("alpha", self.alpha)
|
||||
@@ -107,14 +107,14 @@
|
||||
# # 并行执行关键词搜索和语义搜索
|
||||
# keyword_result = await self.keyword_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
|
||||
# semantic_result = await self.semantic_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
@@ -139,7 +139,7 @@
|
||||
# metadata = self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list,
|
||||
# alpha=alpha,
|
||||
@@ -165,7 +165,7 @@
|
||||
# metadata=self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# error=str(e)
|
||||
# )
|
||||
|
||||
@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
|
||||
self,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
limit: 结果限制
|
||||
**kwargs: 其他元数据
|
||||
|
||||
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
|
||||
metadata = {
|
||||
"query": query_text,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"limit": limit,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
connector=self.connector,
|
||||
embedder_client=self.embedder_client,
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
|
||||
key=lambda eid: (
|
||||
_strength_rank(eid),
|
||||
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
||||
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
0 # 临时占位
|
||||
),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
||||
target_keys = [
|
||||
"id",
|
||||
"statement",
|
||||
"group_id",
|
||||
"end_user_id",
|
||||
"chunk_id",
|
||||
"created_at",
|
||||
"expired_at",
|
||||
@@ -75,7 +75,7 @@ async def get_data(result):
|
||||
"""
|
||||
EXCLUDE_FIELDS = {
|
||||
"user_id",
|
||||
"group_id",
|
||||
"end_user_id",
|
||||
"entity_type",
|
||||
"connect_strength",
|
||||
"relationship_type",
|
||||
|
||||
@@ -62,7 +62,7 @@ class ConfigAuditLogger:
|
||||
self,
|
||||
config_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
success: bool = True,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
user_id: 用户 ID(可选)
|
||||
group_id: 组 ID(可选)
|
||||
end_user_id: 组 ID(可选)
|
||||
success: 是否成功
|
||||
details: 详细信息(可选)
|
||||
"""
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"CONFIG_LOAD config_id={config_id} "
|
||||
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
|
||||
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
|
||||
f"result={result}"
|
||||
)
|
||||
if details:
|
||||
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
|
||||
self,
|
||||
operation: str,
|
||||
config_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
success: bool = True,
|
||||
duration: Optional[float] = None,
|
||||
error: Optional[str] = None,
|
||||
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
|
||||
Args:
|
||||
operation: 操作类型(WRITE, READ 等)
|
||||
config_id: 配置 ID
|
||||
group_id: 组 ID
|
||||
end_user_id: 组 ID
|
||||
success: 是否成功
|
||||
duration: 操作耗时(秒)
|
||||
error: 错误信息(可选)
|
||||
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"{operation.upper()} config_id={config_id} "
|
||||
f"group={group_id} result={result}"
|
||||
f"group={end_user_id} result={result}"
|
||||
)
|
||||
if duration is not None:
|
||||
msg += f" duration={duration:.2f}s"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user