Compare commits
259 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eab7225d83 | ||
|
|
1b853aa893 | ||
|
|
0159fdf149 | ||
|
|
364e01ec7a | ||
|
|
ffb7b0ba38 | ||
|
|
095dfc2879 | ||
|
|
17dea9433e | ||
|
|
c285444e2f | ||
|
|
8ba402d080 | ||
|
|
88ab86734d | ||
|
|
b0d5818351 | ||
|
|
8826a01d32 | ||
|
|
a651ae6ed4 | ||
|
|
ee50b25d06 | ||
|
|
a67be85858 | ||
|
|
59c5a3973a | ||
|
|
d76d7343ff | ||
|
|
2b9638e7d3 | ||
|
|
3459a73705 | ||
|
|
bd480a466b | ||
|
|
4c34cb55b6 | ||
|
|
e137e4a38a | ||
|
|
b5989bbc25 | ||
|
|
c31ff7ceef | ||
|
|
75066f2827 | ||
|
|
303f3aefef | ||
|
|
44fb5e0fd5 | ||
|
|
17a695120a | ||
|
|
6dc716eaf8 | ||
|
|
194be086d4 | ||
|
|
c49603c25b | ||
|
|
8de85a4041 | ||
|
|
58a2135fa4 | ||
|
|
ab9a97db22 | ||
|
|
d291c241d5 | ||
|
|
24d4cb9b94 | ||
|
|
5b9adb799f | ||
|
|
38b41df36b | ||
|
|
34a9befe5c | ||
|
|
67fd579074 | ||
|
|
e2714b942d | ||
|
|
6b2556f870 | ||
|
|
43e6e9d201 | ||
|
|
131e0cc4c7 | ||
|
|
537be81b8f | ||
|
|
765168db7f | ||
|
|
1e16b06a24 | ||
|
|
cd4c93a5cb | ||
|
|
808961243d | ||
|
|
4d80e119f7 | ||
|
|
10c87edae1 | ||
|
|
0eb335d112 | ||
|
|
b8b26ccfe5 | ||
|
|
e89c23da4d | ||
|
|
ced087f8ae | ||
|
|
0f1eed0b1e | ||
|
|
95f15b77a3 | ||
|
|
f9ccfd5ca0 | ||
|
|
7207d7c847 | ||
|
|
00c4a524b7 | ||
|
|
3127c382a4 | ||
|
|
1748a390ec | ||
|
|
a7c0837049 | ||
|
|
44bf1eeae2 | ||
|
|
762b7a8ef1 | ||
|
|
102712a16e | ||
|
|
40810c59d7 | ||
|
|
35a10e86b5 | ||
|
|
c0c985494d | ||
|
|
8984ba7aef | ||
|
|
179869d481 | ||
|
|
5f29956f2b | ||
|
|
7e56c09620 | ||
|
|
dbc4ba84c2 | ||
|
|
9e4a527675 | ||
|
|
2e7f6afe3f | ||
|
|
45833542a7 | ||
|
|
1be6de30d7 | ||
|
|
981d78c8ba | ||
|
|
fbc7bedb6c | ||
|
|
9a4b1f0937 | ||
|
|
4786b0c5d4 | ||
|
|
17bed26096 | ||
|
|
511e16f1d3 | ||
|
|
18204bc1f7 | ||
|
|
e5e914903c | ||
|
|
7ba443afa5 | ||
|
|
b58d97fad3 | ||
|
|
d2a67a53b5 | ||
|
|
c0b556000c | ||
|
|
462c3b0696 | ||
|
|
d34ad73439 | ||
|
|
2c21712d58 | ||
|
|
2862db3534 | ||
|
|
bf3e30dac0 | ||
|
|
ce01e588c9 | ||
|
|
2a23082203 | ||
|
|
d373f924f6 | ||
|
|
eaf46ee006 | ||
|
|
d51355a0ad | ||
|
|
1e481a311a | ||
|
|
375660f232 | ||
|
|
46abb23ee8 | ||
|
|
8555bb697c | ||
|
|
f821893653 | ||
|
|
f6031baee4 | ||
|
|
75b3ea1f05 | ||
|
|
c818ba7bc7 | ||
|
|
74f0018962 | ||
|
|
3a0f07d36f | ||
|
|
8fb9e779a6 | ||
|
|
c5a794f1b5 | ||
|
|
3aa2cdd754 | ||
|
|
d93d52cf10 | ||
|
|
2abbd5a7fb | ||
|
|
2a10e9f7ee | ||
|
|
166d05afe9 | ||
|
|
2eff8d1962 | ||
|
|
93c9e76c4b | ||
|
|
021cb09b82 | ||
|
|
28e6939884 | ||
|
|
8847039d76 | ||
|
|
a047cf2e91 | ||
|
|
a8ae16e321 | ||
|
|
2694576a32 | ||
|
|
e4f10670f6 | ||
|
|
1324ba3a49 | ||
|
|
73c7810310 | ||
|
|
d160076267 | ||
|
|
a53be31765 | ||
|
|
ed8c1c7c19 | ||
|
|
159c8d1ff9 | ||
|
|
8932d455d8 | ||
|
|
3af183f6c3 | ||
|
|
4475be51cc | ||
|
|
c3ea3b751b | ||
|
|
e2c67d0c5b | ||
|
|
87731090ca | ||
|
|
80ca247435 | ||
|
|
a5b8d3afa5 | ||
|
|
1f615a06ad | ||
|
|
4123560a98 | ||
|
|
5267bd60a5 | ||
|
|
f76bffb482 | ||
|
|
51185c83c9 | ||
|
|
f1f887faae | ||
|
|
d53cbe7868 | ||
|
|
722746c78b | ||
|
|
46f0f3cee9 | ||
|
|
e1f5607836 | ||
|
|
ebc41b2eec | ||
|
|
7cd0d78424 | ||
|
|
d740559749 | ||
|
|
399357f752 | ||
|
|
3b4b474ce8 | ||
|
|
4534e46811 | ||
|
|
7bfa7b3f02 | ||
|
|
1cc34d8e62 | ||
|
|
2eff6b2e9d | ||
|
|
b046411302 | ||
|
|
6ab65b3626 | ||
|
|
cf321f9b09 | ||
|
|
8228d38859 | ||
|
|
c2e3110fa2 | ||
|
|
85681db7b7 | ||
|
|
1fc04c37d3 | ||
|
|
0fd8a122fb | ||
|
|
e3b6ede992 | ||
|
|
3601737869 | ||
|
|
9de6b4f151 | ||
|
|
4f4f55d67f | ||
|
|
714c624dc6 | ||
|
|
94cced8323 | ||
|
|
9b8ed16e37 | ||
|
|
a5e44cd229 | ||
|
|
eccc208229 | ||
|
|
79cfabb45d | ||
|
|
af6e1e2b99 | ||
|
|
4ad51c1b24 | ||
|
|
1919580759 | ||
|
|
b27ffe57e6 | ||
|
|
c115bcde54 | ||
|
|
c44712167f | ||
|
|
1aabaff1f2 | ||
|
|
21c0383efb | ||
|
|
313f19eba4 | ||
|
|
c6bcf53fea | ||
|
|
86812b34d1 | ||
|
|
15f9c49418 | ||
|
|
6e18c92a13 | ||
|
|
7870c6c33f | ||
|
|
ebe018347b | ||
|
|
86fe6fe5ab | ||
|
|
9e828b1750 | ||
|
|
45adb9627a | ||
|
|
940d3d4567 | ||
|
|
6bd7b2b8bb | ||
|
|
f2d6fd7b08 | ||
|
|
7219274d94 | ||
|
|
b84c82880c | ||
|
|
fcc418b4a0 | ||
|
|
15c0bb4c9e | ||
|
|
8db4f914d8 | ||
|
|
f3f9211c9c | ||
|
|
51680b7077 | ||
|
|
a2a69840f7 | ||
|
|
3a4a7590c2 | ||
|
|
bcc8b7ce3c | ||
|
|
1c7fe6d134 | ||
|
|
c4039f52bd | ||
|
|
bd851d5e86 | ||
|
|
00e448c5d6 | ||
|
|
4aeec8afbf | ||
|
|
f10432bf3f | ||
|
|
f0efed8aa1 | ||
|
|
4a4931bee2 | ||
|
|
afcf12ebc9 | ||
|
|
8f86d3417d | ||
|
|
92dfc54c4c | ||
|
|
c93bcb8678 | ||
|
|
98b2da9123 | ||
|
|
cd5f1a1b28 | ||
|
|
0e2e495d09 | ||
|
|
84c6c7e2a6 | ||
|
|
c8ebf9c75a | ||
|
|
29852ff0a5 | ||
|
|
f06ca62589 | ||
|
|
3f39a2be12 | ||
|
|
575190a96d | ||
|
|
78559d98eb | ||
|
|
398964c747 | ||
|
|
a634565296 | ||
|
|
a5ecbec9a6 | ||
|
|
fe79978f88 | ||
|
|
978ec8bc75 | ||
|
|
6e77f5b068 | ||
|
|
c9dbb64269 | ||
|
|
546d32e3eb | ||
|
|
616f6401b4 | ||
|
|
d047190453 | ||
|
|
17504b1b9c | ||
|
|
5a0d3df689 | ||
|
|
871304c89b | ||
|
|
8155150e45 | ||
|
|
d9fb8edaa9 | ||
|
|
dda61679bd | ||
|
|
6ac10a8297 | ||
|
|
0695c11739 | ||
|
|
7a4297c4f1 | ||
|
|
2c9e5df27d | ||
|
|
6db37d35ed | ||
|
|
ceee4fe5cf | ||
|
|
130b4a57de | ||
|
|
1cee27e830 | ||
|
|
ba2ff053f9 | ||
|
|
227665439f | ||
|
|
1a2e043ec2 | ||
|
|
89500df0ac | ||
|
|
cb4e80f1bc |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -35,3 +35,6 @@ nltk_data/
|
|||||||
tika-server*.jar*
|
tika-server*.jar*
|
||||||
cl100k_base.tiktoken
|
cl100k_base.tiktoken
|
||||||
libssl*.deb
|
libssl*.deb
|
||||||
|
|
||||||
|
sandbox/lib/seccomp_python/target
|
||||||
|
sandbox/lib/seccomp_nodejs/target
|
||||||
|
|||||||
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
@@ -872,3 +872,44 @@ async def update_workflow_config(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_app_statistics(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
start_date: int,
|
||||||
|
end_date: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""获取应用统计数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用ID
|
||||||
|
start_date: 开始时间戳(毫秒)
|
||||||
|
end_date: 结束时间戳(毫秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- daily_conversations: 每日会话数统计
|
||||||
|
- total_conversations: 总会话数
|
||||||
|
- daily_new_users: 每日新增用户数
|
||||||
|
- total_new_users: 总新增用户数
|
||||||
|
- daily_api_calls: 每日API调用次数
|
||||||
|
- total_api_calls: 总API调用次数
|
||||||
|
- daily_tokens: 每日token消耗
|
||||||
|
- total_tokens: 总token消耗
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
from app.services.app_statistics_service import AppStatisticsService
|
||||||
|
stats_service = AppStatisticsService(db)
|
||||||
|
|
||||||
|
result = stats_service.get_app_statistics(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=result)
|
||||||
|
|||||||
@@ -7,11 +7,13 @@ Routes:
|
|||||||
GET /memory/config/emotion - 获取情绪引擎配置
|
GET /memory/config/emotion - 获取情绪引擎配置
|
||||||
POST /memory/config/emotion - 更新情绪引擎配置
|
POST /memory/config/emotion - 更新情绪引擎配置
|
||||||
"""
|
"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
@@ -20,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
|
|||||||
from app.services.emotion_config_service import EmotionConfigService
|
from app.services.emotion_config_service import EmotionConfigService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -32,11 +35,11 @@ router = APIRouter(
|
|||||||
|
|
||||||
class EmotionConfigQuery(BaseModel):
|
class EmotionConfigQuery(BaseModel):
|
||||||
"""情绪配置查询请求模型"""
|
"""情绪配置查询请求模型"""
|
||||||
config_id: int = Field(..., description="配置ID")
|
config_id: UUID = Field(..., description="配置ID")
|
||||||
|
|
||||||
class EmotionConfigUpdate(BaseModel):
|
class EmotionConfigUpdate(BaseModel):
|
||||||
"""情绪配置更新请求模型"""
|
"""情绪配置更新请求模型"""
|
||||||
config_id: int = Field(..., description="配置ID")
|
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
||||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||||
@@ -45,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
|
|||||||
|
|
||||||
@router.get("/read_config", response_model=ApiResponse)
|
@router.get("/read_config", response_model=ApiResponse)
|
||||||
def get_emotion_config(
|
def get_emotion_config(
|
||||||
config_id: int = Query(..., description="配置ID"),
|
config_id: UUID|int = Query(..., description="配置ID"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -78,7 +81,7 @@ def get_emotion_config(
|
|||||||
f"用户 {current_user.username} 请求获取情绪配置",
|
f"用户 {current_user.username} 请求获取情绪配置",
|
||||||
extra={"config_id": config_id}
|
extra={"config_id": config_id}
|
||||||
)
|
)
|
||||||
|
config_id=resolve_config_id(config_id, db)
|
||||||
# 初始化服务
|
# 初始化服务
|
||||||
config_service = EmotionConfigService(db)
|
config_service = EmotionConfigService(db)
|
||||||
|
|
||||||
@@ -157,6 +160,7 @@ def update_emotion_config(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
config.config_id=resolve_config_id(config.config_id, db)
|
||||||
try:
|
try:
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求更新情绪配置",
|
f"用户 {current_user.username} 请求更新情绪配置",
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ async def get_emotion_tags(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"emotion_type": request.emotion_type,
|
"emotion_type": request.emotion_type,
|
||||||
"start_date": request.start_date,
|
"start_date": request.start_date,
|
||||||
"end_date": request.end_date,
|
"end_date": request.end_date,
|
||||||
@@ -63,7 +63,7 @@ async def get_emotion_tags(
|
|||||||
|
|
||||||
# 调用服务层
|
# 调用服务层
|
||||||
data = await emotion_service.get_emotion_tags(
|
data = await emotion_service.get_emotion_tags(
|
||||||
end_user_id=request.group_id,
|
end_user_id=request.end_user_id,
|
||||||
emotion_type=request.emotion_type,
|
emotion_type=request.emotion_type,
|
||||||
start_date=request.start_date,
|
start_date=request.start_date,
|
||||||
end_date=request.end_date,
|
end_date=request.end_date,
|
||||||
@@ -73,7 +73,7 @@ async def get_emotion_tags(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
"情绪标签统计获取成功",
|
"情绪标签统计获取成功",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"total_count": data.get("total_count", 0),
|
"total_count": data.get("total_count", 0),
|
||||||
"tags_count": len(data.get("tags", []))
|
"tags_count": len(data.get("tags", []))
|
||||||
}
|
}
|
||||||
@@ -84,7 +84,7 @@ async def get_emotion_tags(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(
|
api_logger.error(
|
||||||
f"获取情绪标签统计失败: {str(e)}",
|
f"获取情绪标签统计失败: {str(e)}",
|
||||||
extra={"group_id": request.group_id},
|
extra={"end_user_id": request.end_user_id},
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"emotion_type": request.emotion_type,
|
"emotion_type": request.emotion_type,
|
||||||
"limit": request.limit
|
"limit": request.limit
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
|
|||||||
|
|
||||||
# 调用服务层
|
# 调用服务层
|
||||||
data = await emotion_service.get_emotion_wordcloud(
|
data = await emotion_service.get_emotion_wordcloud(
|
||||||
end_user_id=request.group_id,
|
end_user_id=request.end_user_id,
|
||||||
emotion_type=request.emotion_type,
|
emotion_type=request.emotion_type,
|
||||||
limit=request.limit
|
limit=request.limit
|
||||||
)
|
)
|
||||||
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
"情绪词云数据获取成功",
|
"情绪词云数据获取成功",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"total_keywords": data.get("total_keywords", 0)
|
"total_keywords": data.get("total_keywords", 0)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(
|
api_logger.error(
|
||||||
f"获取情绪词云数据失败: {str(e)}",
|
f"获取情绪词云数据失败: {str(e)}",
|
||||||
extra={"group_id": request.group_id},
|
extra={"end_user_id": request.end_user_id},
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -159,21 +159,21 @@ async def get_emotion_health(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"time_range": request.time_range
|
"time_range": request.time_range
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用服务层
|
# 调用服务层
|
||||||
data = await emotion_service.calculate_emotion_health_index(
|
data = await emotion_service.calculate_emotion_health_index(
|
||||||
end_user_id=request.group_id,
|
end_user_id=request.end_user_id,
|
||||||
time_range=request.time_range
|
time_range=request.time_range
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
"情绪健康指数获取成功",
|
"情绪健康指数获取成功",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"health_score": data.get("health_score", 0),
|
"health_score": data.get("health_score", 0),
|
||||||
"level": data.get("level", "未知")
|
"level": data.get("level", "未知")
|
||||||
}
|
}
|
||||||
@@ -186,7 +186,7 @@ async def get_emotion_health(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(
|
api_logger.error(
|
||||||
f"获取情绪健康指数失败: {str(e)}",
|
f"获取情绪健康指数失败: {str(e)}",
|
||||||
extra={"group_id": request.group_id},
|
extra={"end_user_id": request.end_user_id},
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
|
|||||||
"""获取个性化情绪建议(从缓存读取)
|
"""获取个性化情绪建议(从缓存读取)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 包含 group_id 和可选的 config_id
|
request: 包含 end_user_id 和可选的 config_id
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
current_user: 当前用户
|
current_user: 当前用户
|
||||||
|
|
||||||
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"config_id": request.config_id
|
"config_id": request.config_id
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 从缓存获取建议
|
# 从缓存获取建议
|
||||||
data = await emotion_service.get_cached_suggestions(
|
data = await emotion_service.get_cached_suggestions(
|
||||||
end_user_id=request.group_id,
|
end_user_id=request.end_user_id,
|
||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
# 缓存不存在或已过期
|
# 缓存不存在或已过期
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {request.group_id} 的建议缓存不存在或已过期",
|
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
|
||||||
extra={"group_id": request.group_id}
|
extra={"end_user_id": request.end_user_id}
|
||||||
)
|
)
|
||||||
return fail(
|
return fail(
|
||||||
BizCode.NOT_FOUND,
|
BizCode.NOT_FOUND,
|
||||||
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
|
|||||||
api_logger.info(
|
api_logger.info(
|
||||||
"个性化建议获取成功(缓存)",
|
"个性化建议获取成功(缓存)",
|
||||||
extra={
|
extra={
|
||||||
"group_id": request.group_id,
|
"end_user_id": request.end_user_id,
|
||||||
"suggestions_count": len(data.get("suggestions", []))
|
"suggestions_count": len(data.get("suggestions", []))
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(
|
api_logger.error(
|
||||||
f"获取个性化建议失败: {str(e)}",
|
f"获取个性化建议失败: {str(e)}",
|
||||||
extra={"group_id": request.group_id},
|
extra={"end_user_id": request.end_user_id},
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ async def get_file_url(
|
|||||||
try:
|
try:
|
||||||
if permanent:
|
if permanent:
|
||||||
# Generate permanent URL (no expiration check)
|
# Generate permanent URL (no expiration check)
|
||||||
server_url = f"http://{settings.SERVER_IP}:8000/api"
|
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||||
url = f"{server_url}/storage/permanent/{file_id}"
|
url = f"{server_url}/storage/permanent/{file_id}"
|
||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
|
|||||||
@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
|
|||||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
|
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def get_preference_tags(
|
async def get_preference_tags(
|
||||||
user_id: str,
|
end_user_id: str,
|
||||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||||
@@ -137,7 +137,7 @@ async def get_preference_tags(
|
|||||||
Get user preference tags from cache.
|
Get user preference tags from cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: Target user ID
|
end_user_id: Target end user ID
|
||||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||||
tag_category: Optional category filter
|
tag_category: Optional category filter
|
||||||
start_date: Optional start date filter
|
start_date: Optional start date filter
|
||||||
@@ -146,20 +146,20 @@ async def get_preference_tags(
|
|||||||
Returns:
|
Returns:
|
||||||
List of preference tags from cache
|
List of preference tags from cache
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
|
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate inputs
|
# Validate inputs
|
||||||
validate_user_id(user_id)
|
validate_user_id(end_user_id)
|
||||||
|
|
||||||
# Create service with user-specific config
|
# Create service with user-specific config
|
||||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||||
|
|
||||||
# Get cached profile
|
# Get cached profile
|
||||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(
|
return fail(
|
||||||
BizCode.NOT_FOUND,
|
BizCode.NOT_FOUND,
|
||||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
@@ -192,17 +192,17 @@ async def get_preference_tags(
|
|||||||
|
|
||||||
filtered_preferences.append(pref)
|
filtered_preferences.append(pref)
|
||||||
|
|
||||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
|
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
|
||||||
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
|
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def get_dimension_portrait(
|
async def get_dimension_portrait(
|
||||||
user_id: str,
|
end_user_id: str,
|
||||||
include_history: bool = Query(False, description="Include historical trends"),
|
include_history: bool = Query(False, description="Include historical trends"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
|
|||||||
Get user's four-dimension personality portrait from cache.
|
Get user's four-dimension personality portrait from cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: Target user ID
|
end_user_id: Target end user ID
|
||||||
include_history: Whether to include historical trend data (ignored for cached data)
|
include_history: Whether to include historical trend data (ignored for cached data)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Four-dimension personality portrait from cache
|
Four-dimension personality portrait from cache
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
|
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate inputs
|
# Validate inputs
|
||||||
validate_user_id(user_id)
|
validate_user_id(end_user_id)
|
||||||
|
|
||||||
# Create service with user-specific config
|
# Create service with user-specific config
|
||||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||||
|
|
||||||
# Get cached profile
|
# Get cached profile
|
||||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(
|
return fail(
|
||||||
BizCode.NOT_FOUND,
|
BizCode.NOT_FOUND,
|
||||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
|
|||||||
# Extract portrait from cache
|
# Extract portrait from cache
|
||||||
portrait = cached_profile.get("portrait", {})
|
portrait = cached_profile.get("portrait", {})
|
||||||
|
|
||||||
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
|
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
|
||||||
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
|
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def get_interest_area_distribution(
|
async def get_interest_area_distribution(
|
||||||
user_id: str,
|
end_user_id: str,
|
||||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
|
|||||||
Get user's interest area distribution from cache.
|
Get user's interest area distribution from cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: Target user ID
|
end_user_id: Target end user ID
|
||||||
include_trends: Whether to include trend analysis data (ignored for cached data)
|
include_trends: Whether to include trend analysis data (ignored for cached data)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Interest area distribution from cache
|
Interest area distribution from cache
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
|
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate inputs
|
# Validate inputs
|
||||||
validate_user_id(user_id)
|
validate_user_id(end_user_id)
|
||||||
|
|
||||||
# Create service with user-specific config
|
# Create service with user-specific config
|
||||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||||
|
|
||||||
# Get cached profile
|
# Get cached profile
|
||||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(
|
return fail(
|
||||||
BizCode.NOT_FOUND,
|
BizCode.NOT_FOUND,
|
||||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
|
|||||||
# Extract interest areas from cache
|
# Extract interest areas from cache
|
||||||
interest_areas = cached_profile.get("interest_areas", {})
|
interest_areas = cached_profile.get("interest_areas", {})
|
||||||
|
|
||||||
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
|
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
|
||||||
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
|
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def get_behavior_habits(
|
async def get_behavior_habits(
|
||||||
user_id: str,
|
end_user_id: str,
|
||||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||||
@@ -309,7 +309,7 @@ async def get_behavior_habits(
|
|||||||
Get user's behavioral habits from cache.
|
Get user's behavioral habits from cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: Target user ID
|
end_user_id: Target end user ID
|
||||||
confidence_level: Filter by confidence level (high, medium, low)
|
confidence_level: Filter by confidence level (high, medium, low)
|
||||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||||
time_period: Filter by time period (current, past)
|
time_period: Filter by time period (current, past)
|
||||||
@@ -317,20 +317,20 @@ async def get_behavior_habits(
|
|||||||
Returns:
|
Returns:
|
||||||
List of behavioral habits from cache
|
List of behavioral habits from cache
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
|
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate inputs
|
# Validate inputs
|
||||||
validate_user_id(user_id)
|
validate_user_id(end_user_id)
|
||||||
|
|
||||||
# Create service with user-specific config
|
# Create service with user-specific config
|
||||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||||
|
|
||||||
# Get cached profile
|
# Get cached profile
|
||||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||||
|
|
||||||
if cached_profile is None:
|
if cached_profile is None:
|
||||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||||
return fail(
|
return fail(
|
||||||
BizCode.NOT_FOUND,
|
BizCode.NOT_FOUND,
|
||||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||||
@@ -368,11 +368,11 @@ async def get_behavior_habits(
|
|||||||
|
|
||||||
filtered_habits.append(habit)
|
filtered_habits.append(habit)
|
||||||
|
|
||||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
|
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
|
||||||
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ async def write_server(
|
|||||||
Write service endpoint - processes write operations synchronously
|
Write service endpoint - processes write operations synchronously
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Write request containing message and group_id
|
user_input: Write request containing message and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response with write operation status
|
Response with write operation status
|
||||||
@@ -160,19 +160,18 @@ async def write_server(
|
|||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
|
|
||||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
try:
|
||||||
# 获取标准化的消息列表
|
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
|
||||||
result = await memory_agent_service.write_memory(
|
result = await memory_agent_service.write_memory(
|
||||||
user_input.group_id,
|
user_input.end_user_id,
|
||||||
messages_list, # 传递结构化消息列表
|
messages_list,
|
||||||
config_id,
|
config_id,
|
||||||
db,
|
db,
|
||||||
storage_type,
|
storage_type,
|
||||||
user_rag_memory_id
|
user_rag_memory_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=result, msg="写入成功")
|
return success(data=result, msg="写入成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -196,7 +195,7 @@ async def write_server_async(
|
|||||||
Async write service endpoint - enqueues write processing to Celery
|
Async write service endpoint - enqueues write processing to Celery
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Write request containing message and group_id
|
user_input: Write request containing message and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Task ID for tracking async operation
|
Task ID for tracking async operation
|
||||||
@@ -226,10 +225,10 @@ async def write_server_async(
|
|||||||
try:
|
try:
|
||||||
# 获取标准化的消息列表
|
# 获取标准化的消息列表
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
|
||||||
task = celery_app.send_task(
|
task = celery_app.send_task(
|
||||||
"app.core.memory.agent.write_message",
|
"app.core.memory.agent.write_message",
|
||||||
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
|
||||||
@@ -255,16 +254,14 @@ async def read_server(
|
|||||||
- "2": Direct answer based on context
|
- "2": Direct answer based on context
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Read request with message, history, search_switch, and group_id
|
user_input: Read request with message, history, search_switch, and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response with query answer
|
Response with query answer
|
||||||
"""
|
"""
|
||||||
config_id = user_input.config_id
|
config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
|
||||||
|
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
@@ -279,12 +276,13 @@ async def read_server(
|
|||||||
name="USER_RAG_MERORY",
|
name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
if knowledge:
|
||||||
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.read_memory(
|
result = await memory_agent_service.read_memory(
|
||||||
user_input.group_id,
|
user_input.end_user_id,
|
||||||
user_input.message,
|
user_input.message,
|
||||||
user_input.history,
|
user_input.history,
|
||||||
user_input.search_switch,
|
user_input.search_switch,
|
||||||
@@ -295,17 +293,20 @@ async def read_server(
|
|||||||
)
|
)
|
||||||
if str(user_input.search_switch) == "2":
|
if str(user_input.search_switch) == "2":
|
||||||
retrieve_info = result['answer']
|
retrieve_info = result['answer']
|
||||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||||
query = user_input.message
|
query = user_input.message
|
||||||
|
|
||||||
# 调用 memory_agent_service 的方法生成最终答案
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||||
|
end_user_id=user_input.end_user_id,
|
||||||
retrieve_info=retrieve_info,
|
retrieve_info=retrieve_info,
|
||||||
history=history,
|
history=history,
|
||||||
query=query,
|
query=query,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
|
if "信息不足,无法回答" in result['answer']:
|
||||||
|
result['answer']=retrieve_info
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -403,7 +404,7 @@ async def read_server_async(
|
|||||||
try:
|
try:
|
||||||
task = celery_app.send_task(
|
task = celery_app.send_task(
|
||||||
"app.core.memory.agent.read_message",
|
"app.core.memory.agent.read_message",
|
||||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
||||||
config_id, storage_type, user_rag_memory_id]
|
config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
api_logger.info(f"Read task queued: {task.id}")
|
api_logger.info(f"Read task queued: {task.id}")
|
||||||
@@ -447,7 +448,7 @@ async def get_read_task_result(
|
|||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
"result": task_result.get("result"),
|
"result": task_result.get("result"),
|
||||||
"group_id": task_result.get("group_id"),
|
"end_user_id": task_result.get("end_user_id"),
|
||||||
"elapsed_time": task_result.get("elapsed_time"),
|
"elapsed_time": task_result.get("elapsed_time"),
|
||||||
"task_id": task_id
|
"task_id": task_id
|
||||||
},
|
},
|
||||||
@@ -524,7 +525,7 @@ async def get_write_task_result(
|
|||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
"result": task_result.get("result"),
|
"result": task_result.get("result"),
|
||||||
"group_id": task_result.get("group_id"),
|
"end_user_id": task_result.get("end_user_id"),
|
||||||
"elapsed_time": task_result.get("elapsed_time"),
|
"elapsed_time": task_result.get("elapsed_time"),
|
||||||
"task_id": task_id
|
"task_id": task_id
|
||||||
},
|
},
|
||||||
@@ -578,16 +579,16 @@ async def status_type(
|
|||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_input: Request containing user message and group_id
|
user_input: Request containing user message and end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Type classification result
|
Type classification result
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
|
||||||
try:
|
try:
|
||||||
# 获取标准化的消息列表
|
# 获取标准化的消息列表
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
|
||||||
# 将消息列表转换为字符串用于分类
|
# 将消息列表转换为字符串用于分类
|
||||||
# 只取最后一条用户消息进行分类
|
# 只取最后一条用户消息进行分类
|
||||||
last_user_message = ""
|
last_user_message = ""
|
||||||
@@ -595,11 +596,11 @@ async def status_type(
|
|||||||
if msg.get('role') == 'user':
|
if msg.get('role') == 'user':
|
||||||
last_user_message = msg.get('content', '')
|
last_user_message = msg.get('content', '')
|
||||||
break
|
break
|
||||||
|
|
||||||
if not last_user_message:
|
if not last_user_message:
|
||||||
# 如果没有用户消息,使用所有消息的内容
|
# 如果没有用户消息,使用所有消息的内容
|
||||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||||
|
|
||||||
result = await memory_agent_service.classify_message_type(
|
result = await memory_agent_service.classify_message_type(
|
||||||
last_user_message,
|
last_user_message,
|
||||||
user_input.config_id,
|
user_input.config_id,
|
||||||
@@ -624,7 +625,7 @@ async def get_knowledge_type_stats_api(
|
|||||||
会对缺失类型补 0,返回字典形式。
|
会对缺失类型补 0,返回字典形式。
|
||||||
可选按状态过滤。
|
可选按状态过滤。
|
||||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||||
@@ -697,7 +698,7 @@ async def get_user_profile_api(
|
|||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取工作空间下Popular Memory Tags,包含:
|
获取用户详情,包含:
|
||||||
- name: 用户名字(直接使用 end_user_id)
|
- name: 用户名字(直接使用 end_user_id)
|
||||||
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
||||||
- hot_tags: 4个热门记忆标签
|
- hot_tags: 4个热门记忆标签
|
||||||
|
|||||||
@@ -49,63 +49,134 @@ async def get_workspace_end_users(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取工作空间的宿主列表
|
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||||
|
|
||||||
返回格式与原 memory_list 接口中的 end_users 字段相同,
|
优化策略:
|
||||||
并包含每个用户的记忆配置信息(memory_config_id 和 memory_config_name)
|
1. 批量查询 end_users(一次查询而非循环)
|
||||||
|
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||||
|
3. RAG 模式使用批量查询(一次 SQL)
|
||||||
|
4. 只返回必要字段减少数据传输
|
||||||
|
5. 添加短期缓存减少重复查询
|
||||||
|
6. 并发执行配置查询和记忆数量查询
|
||||||
|
|
||||||
|
返回格式:
|
||||||
|
{
|
||||||
|
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||||
|
"memory_num": {"total": 数量},
|
||||||
|
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 尝试从缓存获取(30秒缓存)
|
||||||
|
cache_key = f"end_users:workspace:{workspace_id}"
|
||||||
|
try:
|
||||||
|
cached_data = await aio_redis_get(cache_key)
|
||||||
|
if cached_data:
|
||||||
|
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||||
|
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||||
|
|
||||||
# 获取当前空间类型
|
# 获取当前空间类型
|
||||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||||
|
|
||||||
|
# 获取 end_users(已优化为批量查询)
|
||||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
if not end_users:
|
||||||
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
|
api_logger.info("工作空间下没有宿主")
|
||||||
end_user_ids = [str(user.id) for user in end_users]
|
# 缓存空结果,避免重复查询
|
||||||
memory_configs_map = {}
|
|
||||||
if end_user_ids:
|
|
||||||
try:
|
try:
|
||||||
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
|
return success(data=[], msg="宿主列表获取成功")
|
||||||
|
|
||||||
|
end_user_ids = [str(user.id) for user in end_users]
|
||||||
|
|
||||||
|
# 并发执行两个独立的查询任务
|
||||||
|
async def get_memory_configs():
|
||||||
|
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
get_end_users_connected_configs_batch,
|
||||||
|
end_user_ids, db
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||||
# 失败时使用空字典,不影响其他数据返回
|
return {}
|
||||||
|
|
||||||
|
async def get_memory_nums():
|
||||||
|
"""获取记忆数量"""
|
||||||
|
if current_workspace_type == "rag":
|
||||||
|
# RAG 模式:批量查询
|
||||||
|
try:
|
||||||
|
chunk_map = await asyncio.to_thread(
|
||||||
|
memory_dashboard_service.get_users_total_chunk_batch,
|
||||||
|
end_user_ids, db, current_user
|
||||||
|
)
|
||||||
|
return {uid: {"total": count} for uid, count in chunk_map.items()}
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||||
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
|
elif current_workspace_type == "neo4j":
|
||||||
|
# Neo4j 模式:并发查询(带并发限制)
|
||||||
|
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||||
|
MAX_CONCURRENT_QUERIES = 10
|
||||||
|
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||||
|
|
||||||
|
async def get_neo4j_memory_num(end_user_id: str):
|
||||||
|
async with semaphore:
|
||||||
|
try:
|
||||||
|
return await memory_storage_service.search_all(end_user_id)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||||
|
return {"total": 0}
|
||||||
|
|
||||||
|
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||||
|
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||||
|
|
||||||
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
|
# 并发执行配置查询和记忆数量查询
|
||||||
|
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||||
|
get_memory_configs(),
|
||||||
|
get_memory_nums()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建结果(优化:使用列表推导式)
|
||||||
result = []
|
result = []
|
||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
memory_num = {}
|
|
||||||
if current_workspace_type == "neo4j":
|
|
||||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
|
||||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
|
||||||
elif current_workspace_type == "rag":
|
|
||||||
memory_num = {
|
|
||||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
|
||||||
}
|
|
||||||
|
|
||||||
# 从批量查询结果中获取配置信息
|
|
||||||
user_id = str(end_user.id)
|
user_id = str(end_user.id)
|
||||||
memory_config_info = memory_configs_map.get(user_id, {
|
config_info = memory_configs_map.get(user_id, {})
|
||||||
"memory_config_id": None,
|
result.append({
|
||||||
"memory_config_name": None
|
'end_user': {
|
||||||
})
|
'id': user_id,
|
||||||
|
'other_name': end_user.other_name
|
||||||
# 只保留需要的字段,移除 error 字段(如果有)
|
},
|
||||||
memory_config = {
|
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
|
||||||
"memory_config_id": memory_config_info.get("memory_config_id"),
|
'memory_config': {
|
||||||
"memory_config_name": memory_config_info.get("memory_config_name")
|
"memory_config_id": config_info.get("memory_config_id"),
|
||||||
}
|
"memory_config_name": config_info.get("memory_config_name")
|
||||||
|
|
||||||
result.append(
|
|
||||||
{
|
|
||||||
'end_user': end_user,
|
|
||||||
'memory_num': memory_num,
|
|
||||||
'memory_config': memory_config
|
|
||||||
}
|
}
|
||||||
)
|
})
|
||||||
|
|
||||||
|
# 写入缓存(30秒过期)
|
||||||
|
try:
|
||||||
|
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||||
return success(data=result, msg="宿主列表获取成功")
|
return success(data=result, msg="宿主列表获取成功")
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -33,7 +34,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -83,7 +84,8 @@ async def trigger_forgetting_cycle(
|
|||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id((config_id), db)
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
@@ -106,7 +108,7 @@ async def trigger_forgetting_cycle(
|
|||||||
# 调用服务层执行遗忘周期
|
# 调用服务层执行遗忘周期
|
||||||
report = await forget_service.trigger_forgetting_cycle(
|
report = await forget_service.trigger_forgetting_cycle(
|
||||||
db=db,
|
db=db,
|
||||||
group_id=end_user_id, # 服务层方法的参数名是 group_id
|
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
|
||||||
max_merge_batch_size=payload.max_merge_batch_size,
|
max_merge_batch_size=payload.max_merge_batch_size,
|
||||||
min_days_since_access=payload.min_days_since_access,
|
min_days_since_access=payload.min_days_since_access,
|
||||||
config_id=config_id
|
config_id=config_id
|
||||||
@@ -128,7 +130,7 @@ async def trigger_forgetting_cycle(
|
|||||||
|
|
||||||
@router.get("/read_config", response_model=ApiResponse)
|
@router.get("/read_config", response_model=ApiResponse)
|
||||||
async def read_forgetting_config(
|
async def read_forgetting_config(
|
||||||
config_id: int,
|
config_id: UUID|int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -157,6 +159,7 @@ async def read_forgetting_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
config_id=resolve_config_id(config_id, db)
|
||||||
# 调用服务层读取配置
|
# 调用服务层读取配置
|
||||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||||
|
|
||||||
@@ -194,6 +197,8 @@ async def update_forgetting_config(
|
|||||||
ApiResponse: 包含更新结果的响应
|
ApiResponse: 包含更新结果的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
payload.config_id=resolve_config_id((payload.config_id), db)
|
||||||
|
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
@@ -236,7 +241,7 @@ async def update_forgetting_config(
|
|||||||
|
|
||||||
@router.get("/stats", response_model=ApiResponse)
|
@router.get("/stats", response_model=ApiResponse)
|
||||||
async def get_forgetting_stats(
|
async def get_forgetting_stats(
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -246,7 +251,7 @@ async def get_forgetting_stats(
|
|||||||
返回知识层节点统计、激活值分布等信息。
|
返回知识层节点统计、激活值分布等信息。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组ID(即 end_user_id,可选)
|
end_user_id: 组ID(即 end_user_id,可选)
|
||||||
current_user: 当前用户
|
current_user: 当前用户
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
|
|
||||||
@@ -254,26 +259,25 @@ async def get_forgetting_stats(
|
|||||||
ApiResponse: 包含统计信息的响应
|
ApiResponse: 包含统计信息的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
# 如果提供了 end_user_id,通过它获取 config_id
|
||||||
# 如果提供了 group_id,通过它获取 config_id
|
|
||||||
config_id = None
|
config_id = None
|
||||||
if group_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
|
|
||||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||||
@@ -283,14 +287,14 @@ async def get_forgetting_stats(
|
|||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||||
f"group_id={group_id}, config_id={config_id}"
|
f"end_user_id={end_user_id}, config_id={config_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取统计信息
|
# 调用服务层获取统计信息
|
||||||
stats = await forget_service.get_forgetting_stats(
|
stats = await forget_service.get_forgetting_stats(
|
||||||
db=db,
|
db=db,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
config_id=config_id
|
config_id=config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -324,7 +328,7 @@ async def get_forgetting_curve(
|
|||||||
ApiResponse: 包含遗忘曲线数据的响应
|
ApiResponse: 包含遗忘曲线数据的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
request.config_id = resolve_config_id((request.config_id), db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||||
|
|||||||
@@ -27,27 +27,27 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||||
def get_memory_count(
|
def get_memory_count(
|
||||||
group_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Retrieve perceptual memory statistics for a user group.
|
"""Retrieve perceptual memory statistics for a user group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: ID of the user group (usually end_user_id in this context)
|
end_user_id: ID of the user group (usually end_user_id in this context)
|
||||||
current_user: Current authenticated user
|
current_user: Current authenticated user
|
||||||
db: Database session
|
db: Database session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Response containing memory count statistics
|
ApiResponse: Response containing memory count statistics
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
|
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
service = MemoryPerceptualService(db)
|
service = MemoryPerceptualService(db)
|
||||||
count_stats = service.get_memory_count(group_id)
|
count_stats = service.get_memory_count(end_user_id)
|
||||||
|
|
||||||
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||||
|
|
||||||
@@ -57,37 +57,37 @@ def get_memory_count(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
|
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(
|
return fail(
|
||||||
code=BizCode.INTERNAL_ERROR,
|
code=BizCode.INTERNAL_ERROR,
|
||||||
msg="Failed to fetch memory statistics",
|
msg="Failed to fetch memory statistics",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
|
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
|
||||||
def get_last_visual_memory(
|
def get_last_visual_memory(
|
||||||
group_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Retrieve the most recent VISION-type memory for a user.
|
"""Retrieve the most recent VISION-type memory for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: ID of the user group
|
end_user_id: ID of the user group
|
||||||
current_user: Current authenticated user
|
current_user: Current authenticated user
|
||||||
db: Database session
|
db: Database session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Metadata of the latest visual memory
|
ApiResponse: Metadata of the latest visual memory
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
|
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
service = MemoryPerceptualService(db)
|
service = MemoryPerceptualService(db)
|
||||||
visual_memory = service.get_latest_visual_memory(group_id)
|
visual_memory = service.get_latest_visual_memory(end_user_id)
|
||||||
|
|
||||||
if visual_memory is None:
|
if visual_memory is None:
|
||||||
api_logger.info(f"No visual memory found: group_id={group_id}")
|
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
|
||||||
return success(
|
return success(
|
||||||
data=None,
|
data=None,
|
||||||
msg="No visual memory available"
|
msg="No visual memory available"
|
||||||
@@ -101,37 +101,37 @@ def get_last_visual_memory(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
|
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(
|
return fail(
|
||||||
code=BizCode.INTERNAL_ERROR,
|
code=BizCode.INTERNAL_ERROR,
|
||||||
msg="Failed to fetch latest visual memory",
|
msg="Failed to fetch latest visual memory",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
|
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
|
||||||
def get_last_memory_listen(
|
def get_last_memory_listen(
|
||||||
group_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Retrieve the most recent AUDIO-type memory for a user.
|
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: ID of the user group
|
end_user_id: ID of the user group
|
||||||
current_user: Current authenticated user
|
current_user: Current authenticated user
|
||||||
db: Database session
|
db: Database session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Metadata of the latest audio memory
|
ApiResponse: Metadata of the latest audio memory
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
|
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
service = MemoryPerceptualService(db)
|
service = MemoryPerceptualService(db)
|
||||||
audio_memory = service.get_latest_audio_memory(group_id)
|
audio_memory = service.get_latest_audio_memory(end_user_id)
|
||||||
|
|
||||||
if audio_memory is None:
|
if audio_memory is None:
|
||||||
api_logger.info(f"No audio memory found: group_id={group_id}")
|
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
|
||||||
return success(
|
return success(
|
||||||
data=None,
|
data=None,
|
||||||
msg="No audio memory available"
|
msg="No audio memory available"
|
||||||
@@ -145,38 +145,38 @@ def get_last_memory_listen(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
|
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(
|
return fail(
|
||||||
code=BizCode.INTERNAL_ERROR,
|
code=BizCode.INTERNAL_ERROR,
|
||||||
msg="Failed to fetch latest audio memory",
|
msg="Failed to fetch latest audio memory",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/last_text", response_model=ApiResponse)
|
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
|
||||||
def get_last_text_memory(
|
def get_last_text_memory(
|
||||||
group_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Retrieve the most recent TEXT-type memory for a user.
|
"""Retrieve the most recent TEXT-type memory for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: ID of the user group
|
end_user_id: ID of the user group
|
||||||
current_user: Current authenticated user
|
current_user: Current authenticated user
|
||||||
db: Database session
|
db: Database session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Metadata of the latest text memory
|
ApiResponse: Metadata of the latest text memory
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
|
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取最近的文本记忆
|
# 调用服务层获取最近的文本记忆
|
||||||
service = MemoryPerceptualService(db)
|
service = MemoryPerceptualService(db)
|
||||||
text_memory = service.get_latest_text_memory(group_id)
|
text_memory = service.get_latest_text_memory(end_user_id)
|
||||||
|
|
||||||
if text_memory is None:
|
if text_memory is None:
|
||||||
api_logger.info(f"No text memory found: group_id={group_id}")
|
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
|
||||||
return success(
|
return success(
|
||||||
data=None,
|
data=None,
|
||||||
msg="No text memory available"
|
msg="No text memory available"
|
||||||
@@ -190,16 +190,16 @@ def get_last_text_memory(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
|
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(
|
return fail(
|
||||||
code=BizCode.INTERNAL_ERROR,
|
code=BizCode.INTERNAL_ERROR,
|
||||||
msg="Failed to fetch latest text memory",
|
msg="Failed to fetch latest text memory",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/timeline", response_model=ApiResponse)
|
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
|
||||||
def get_memory_time_line(
|
def get_memory_time_line(
|
||||||
group_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||||
page: int = Query(1, ge=1, description="页码"),
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||||
@@ -209,7 +209,7 @@ def get_memory_time_line(
|
|||||||
"""Retrieve a timeline of perceptual memories for a user group.
|
"""Retrieve a timeline of perceptual memories for a user group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: ID of the user group
|
end_user_id: ID of the user group
|
||||||
perceptual_type: Optional filter for perceptual type
|
perceptual_type: Optional filter for perceptual type
|
||||||
page: Page number for pagination
|
page: Page number for pagination
|
||||||
page_size: Number of items per page
|
page_size: Number of items per page
|
||||||
@@ -221,7 +221,7 @@ def get_memory_time_line(
|
|||||||
"""
|
"""
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||||
f"group_id={group_id}, type={perceptual_type}, page={page}"
|
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -232,7 +232,7 @@ def get_memory_time_line(
|
|||||||
)
|
)
|
||||||
|
|
||||||
service = MemoryPerceptualService(db)
|
service = MemoryPerceptualService(db)
|
||||||
timeline_data = service.get_time_line(group_id, query)
|
timeline_data = service.get_time_line(end_user_id, query)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||||
@@ -246,7 +246,7 @@ def get_memory_time_line(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(
|
api_logger.error(
|
||||||
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
|
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
|
||||||
f"error={str(e)}"
|
f"error={str(e)}"
|
||||||
)
|
)
|
||||||
return fail(
|
return fail(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||||
@@ -11,7 +12,7 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.repositories.data_config_repository import DataConfigRepository
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||||
from app.services.memory_reflection_service import (
|
from app.services.memory_reflection_service import (
|
||||||
@@ -24,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
|||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -42,6 +45,7 @@ async def save_reflection_config(
|
|||||||
"""Save reflection configuration to data_comfig table"""
|
"""Save reflection configuration to data_comfig table"""
|
||||||
try:
|
try:
|
||||||
config_id = request.config_id
|
config_id = request.config_id
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
if not config_id:
|
if not config_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@@ -50,7 +54,7 @@ async def save_reflection_config(
|
|||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
data_config = DataConfigRepository.update_reflection_config(
|
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||||
db,
|
db,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
enable_self_reflexion=request.reflection_enabled,
|
enable_self_reflexion=request.reflection_enabled,
|
||||||
@@ -63,17 +67,17 @@ async def save_reflection_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(data_config)
|
db.refresh(memory_config)
|
||||||
|
|
||||||
reflection_result={
|
reflection_result={
|
||||||
"config_id": data_config.config_id,
|
"config_id": memory_config.config_id,
|
||||||
"enable_self_reflexion": data_config.enable_self_reflexion,
|
"enable_self_reflexion": memory_config.enable_self_reflexion,
|
||||||
"iteration_period": data_config.iteration_period,
|
"iteration_period": memory_config.iteration_period,
|
||||||
"reflexion_range": data_config.reflexion_range,
|
"reflexion_range": memory_config.reflexion_range,
|
||||||
"baseline": data_config.baseline,
|
"baseline": memory_config.baseline,
|
||||||
"reflection_model_id": data_config.reflection_model_id,
|
"reflection_model_id": memory_config.reflection_model_id,
|
||||||
"memory_verify": data_config.memory_verify,
|
"memory_verify": memory_config.memory_verify,
|
||||||
"quality_assessment": data_config.quality_assessment}
|
"quality_assessment": memory_config.quality_assessment}
|
||||||
|
|
||||||
return success(data=reflection_result, msg="反思配置成功")
|
return success(data=reflection_result, msg="反思配置成功")
|
||||||
|
|
||||||
@@ -111,14 +115,14 @@ async def start_workspace_reflection(
|
|||||||
reflection_results = []
|
reflection_results = []
|
||||||
|
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
if data['data_configs'] == []:
|
if data['memory_configs'] == []:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
releases = data['releases']
|
releases = data['releases']
|
||||||
data_configs = data['data_configs']
|
memory_configs = data['memory_configs']
|
||||||
end_users = data['end_users']
|
end_users = data['end_users']
|
||||||
|
|
||||||
for base, config, user in zip(releases, data_configs, end_users):
|
for base, config, user in zip(releases, memory_configs, end_users):
|
||||||
# 安全地转换为整数,处理空字符串和None的情况
|
# 安全地转换为整数,处理空字符串和None的情况
|
||||||
print(base['config'])
|
print(base['config'])
|
||||||
try:
|
try:
|
||||||
@@ -156,17 +160,20 @@ async def start_workspace_reflection(
|
|||||||
|
|
||||||
@router.get("/reflection/configs")
|
@router.get("/reflection/configs")
|
||||||
async def start_reflection_configs(
|
async def start_reflection_configs(
|
||||||
config_id: int,
|
config_id: uuid.UUID|int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""通过config_id查询data_config表中的反思配置信息"""
|
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
try:
|
try:
|
||||||
|
config_id=resolve_config_id(config_id,db)
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
|
memory_config_id = resolve_config_id(result.config_id, db)
|
||||||
# 构建返回数据
|
# 构建返回数据
|
||||||
reflection_config = {
|
reflection_config = {
|
||||||
"config_id": result.config_id,
|
"config_id": memory_config_id,
|
||||||
"reflection_enabled": result.enable_self_reflexion,
|
"reflection_enabled": result.enable_self_reflexion,
|
||||||
"reflection_period_in_hours": result.iteration_period,
|
"reflection_period_in_hours": result.iteration_period,
|
||||||
"reflexion_range": result.reflexion_range,
|
"reflexion_range": result.reflexion_range,
|
||||||
@@ -191,7 +198,7 @@ async def start_reflection_configs(
|
|||||||
|
|
||||||
@router.get("/reflection/run")
|
@router.get("/reflection/run")
|
||||||
async def reflection_run(
|
async def reflection_run(
|
||||||
config_id: int,
|
config_id: UUID|int,
|
||||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -199,9 +206,9 @@ async def reflection_run(
|
|||||||
"""Activate the reflection function for all matching applications in the workspace"""
|
"""Activate the reflection function for all matching applications in the workspace"""
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 使用DataConfigRepository查询反思配置
|
# 使用MemoryConfigRepository查询反思配置
|
||||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
@@ -34,6 +35,8 @@ from fastapi import APIRouter, Depends
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# Get API logger
|
# Get API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -140,7 +143,6 @@ def create_config(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||||
@@ -160,12 +162,12 @@ def create_config(
|
|||||||
|
|
||||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||||
def delete_config(
|
def delete_config(
|
||||||
config_id: str,
|
config_id: UUID|int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
config_id=resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||||
@@ -187,7 +189,7 @@ def update_config(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||||
@@ -210,7 +212,7 @@ def update_config_extracted(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||||
@@ -232,12 +234,12 @@ def update_config_extracted(
|
|||||||
|
|
||||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||||
def read_config_extracted(
|
def read_config_extracted(
|
||||||
config_id: str,
|
config_id: UUID | int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||||
@@ -285,6 +287,7 @@ async def pilot_run(
|
|||||||
f"Pilot run requested: config_id={payload.config_id}, "
|
f"Pilot run requested: config_id={payload.config_id}, "
|
||||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||||
)
|
)
|
||||||
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
svc.pilot_run_stream(payload),
|
svc.pilot_run_stream(payload),
|
||||||
@@ -420,15 +423,95 @@ async def get_hot_memory_tags_api(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
|
"""
|
||||||
|
获取热门记忆标签(带Redis缓存)
|
||||||
|
|
||||||
|
缓存策略:
|
||||||
|
- 缓存键:workspace_id + limit
|
||||||
|
- 过期时间:5分钟(300秒)
|
||||||
|
- 缓存命中:~50ms
|
||||||
|
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 构建缓存键
|
||||||
|
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||||
|
|
||||||
|
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 尝试从Redis缓存获取
|
||||||
|
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||||
|
import json
|
||||||
|
|
||||||
|
cached_result = await aio_redis_get(cache_key)
|
||||||
|
if cached_result:
|
||||||
|
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||||
|
try:
|
||||||
|
data = json.loads(cached_result)
|
||||||
|
return success(data=data, msg="查询成功(缓存)")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||||
|
|
||||||
|
# 缓存未命中,执行查询
|
||||||
|
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||||
|
|
||||||
|
# 写入缓存(过期时间:5分钟)
|
||||||
|
# 注意:result是列表,需要转换为JSON字符串
|
||||||
|
try:
|
||||||
|
cache_data = json.dumps(result, ensure_ascii=False)
|
||||||
|
await aio_redis_set(cache_key, cache_data, expire=300)
|
||||||
|
api_logger.info(f"Cached result for key: {cache_key}")
|
||||||
|
except Exception as cache_error:
|
||||||
|
# 缓存写入失败不影响主流程
|
||||||
|
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||||
|
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||||
|
async def clear_hot_memory_tags_cache(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
清除热门标签缓存
|
||||||
|
|
||||||
|
用于:
|
||||||
|
- 手动刷新数据
|
||||||
|
- 调试和测试
|
||||||
|
- 数据更新后立即生效
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.aioRedis import aio_redis_delete
|
||||||
|
|
||||||
|
# 清除所有limit的缓存(常见的limit值)
|
||||||
|
cleared_count = 0
|
||||||
|
for limit in [5, 10, 15, 20, 30, 50]:
|
||||||
|
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||||
|
result = await aio_redis_delete(cache_key)
|
||||||
|
if result:
|
||||||
|
cleared_count += 1
|
||||||
|
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={"cleared_count": cleared_count},
|
||||||
|
msg=f"成功清除 {cleared_count} 个缓存"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||||
async def get_recent_activity_stats_api(
|
async def get_recent_activity_stats_api(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
|||||||
@@ -20,18 +20,18 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||||
def get_memory_count(
|
def get_memory_count(
|
||||||
group_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/conversations", response_model=ApiResponse)
|
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||||
def get_conversations(
|
def get_conversations(
|
||||||
group_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -39,7 +39,7 @@ def get_conversations(
|
|||||||
Retrieve all conversations for the current user in a specific group.
|
Retrieve all conversations for the current user in a specific group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id (UUID): The group identifier.
|
end_user_id (UUID): The group identifier.
|
||||||
current_user (User, optional): The authenticated user.
|
current_user (User, optional): The authenticated user.
|
||||||
db (Session, optional): SQLAlchemy session.
|
db (Session, optional): SQLAlchemy session.
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ def get_conversations(
|
|||||||
"""
|
"""
|
||||||
conversation_service = ConversationService(db)
|
conversation_service = ConversationService(db)
|
||||||
conversations = conversation_service.get_user_conversations(
|
conversations = conversation_service.get_user_conversations(
|
||||||
group_id
|
end_user_id
|
||||||
)
|
)
|
||||||
return success(data=[
|
return success(data=[
|
||||||
{
|
{
|
||||||
@@ -63,7 +63,7 @@ def get_conversations(
|
|||||||
], msg="get conversations success")
|
], msg="get conversations success")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/messages", response_model=ApiResponse)
|
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||||
def get_messages(
|
def get_messages(
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -100,7 +100,7 @@ def get_messages(
|
|||||||
return success(data=messages, msg="get conversation history success")
|
return success(data=messages, msg="get conversation history success")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{group_id}/detail", response_model=ApiResponse)
|
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
|
||||||
async def get_conversation_detail(
|
async def get_conversation_detail(
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
|||||||
@@ -3,15 +3,17 @@ from sqlalchemy.orm import Session
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.models_model import ModelProvider, ModelType
|
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.repositories.model_repository import ModelConfigRepository
|
||||||
from app.schemas import model_schema
|
from app.schemas import model_schema
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.schemas.response_schema import ApiResponse, PageData
|
from app.schemas.response_schema import ApiResponse, PageData
|
||||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
@@ -24,24 +26,83 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/type", response_model=ApiResponse)
|
@router.get("/type", response_model=ApiResponse)
|
||||||
def get_model_types():
|
def get_model_types():
|
||||||
|
|
||||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/provider", response_model=ApiResponse)
|
@router.get("/provider", response_model=ApiResponse)
|
||||||
def get_model_providers():
|
def get_model_providers():
|
||||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||||
|
return success(msg="获取模型提供商成功", data=providers)
|
||||||
|
|
||||||
|
@router.get("/strategy", response_model=ApiResponse)
|
||||||
|
def get_model_strategies():
|
||||||
|
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
|
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取模型配置列表
|
||||||
|
|
||||||
|
支持多个 type 参数:
|
||||||
|
- 单个:?type=LLM
|
||||||
|
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||||
|
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析 type 参数(支持逗号分隔)
|
||||||
|
type_list = []
|
||||||
|
if type is not None:
|
||||||
|
flat_type = []
|
||||||
|
for item in type:
|
||||||
|
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||||
|
flat_type.extend(split_items)
|
||||||
|
|
||||||
|
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||||
|
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||||
|
|
||||||
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
|
query = model_schema.ModelConfigQuery(
|
||||||
|
type=type_list,
|
||||||
|
provider=provider,
|
||||||
|
is_active=is_active,
|
||||||
|
is_public=is_public,
|
||||||
|
search=search,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||||
|
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||||
|
result = PageData.model_validate(result_orm)
|
||||||
|
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||||
|
return success(data=result, msg="模型配置列表获取成功")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/new", response_model=ApiResponse)
|
||||||
|
def get_model_list_new(
|
||||||
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
page: int = Query(1, ge=1, description="页码"),
|
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
|
||||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -53,36 +114,127 @@ def get_model_list(
|
|||||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析 type 参数(支持逗号分隔)
|
# 解析 type 参数(支持逗号分隔)
|
||||||
type_list = None
|
type_list = []
|
||||||
if type:
|
if type is not None:
|
||||||
type_values = [t.strip() for t in type.split(',')]
|
flat_type = []
|
||||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
for item in type:
|
||||||
|
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||||
|
flat_type.extend(split_items)
|
||||||
|
|
||||||
|
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||||
|
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||||
|
|
||||||
api_logger.error(f"获取模型type_list: {type_list}")
|
api_logger.info(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQueryNew(
|
||||||
type=type_list,
|
type=type_list,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
search=search,
|
is_composite=is_composite,
|
||||||
page=page,
|
search=search
|
||||||
pagesize=pagesize
|
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
|
||||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||||
result = PageData.model_validate(result_orm)
|
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
|
||||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
|
||||||
return success(data=result, msg="模型配置列表获取成功")
|
return success(data=result, msg="模型配置列表获取成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/model_plaza", response_model=ApiResponse)
|
||||||
|
def get_model_plaza_list(
|
||||||
|
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||||
|
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||||
|
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||||
|
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
||||||
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""模型广场查询接口(按供应商分组)"""
|
||||||
|
|
||||||
|
query = model_schema.ModelBaseQuery(
|
||||||
|
type=type,
|
||||||
|
provider=provider,
|
||||||
|
is_official=is_official,
|
||||||
|
is_deprecated=is_deprecated,
|
||||||
|
search=search
|
||||||
|
)
|
||||||
|
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||||
|
return success(data=result, msg="模型广场列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||||
|
def get_model_base_by_id(
|
||||||
|
model_base_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取基础模型详情"""
|
||||||
|
|
||||||
|
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
|
||||||
|
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/model_plaza", response_model=ApiResponse)
|
||||||
|
def create_model_base(
|
||||||
|
data: model_schema.ModelBaseCreate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""创建基础模型"""
|
||||||
|
|
||||||
|
result = ModelBaseService.create_model_base(db=db, data=data)
|
||||||
|
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||||
|
def update_model_base(
|
||||||
|
model_base_id: uuid.UUID,
|
||||||
|
data: model_schema.ModelBaseUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""更新基础模型"""
|
||||||
|
|
||||||
|
# 不允许更改type类型
|
||||||
|
if data.type is not None or data.provider is not None:
|
||||||
|
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
|
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||||
|
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||||
|
def delete_model_base(
|
||||||
|
model_base_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""删除基础模型"""
|
||||||
|
|
||||||
|
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
|
||||||
|
return success(msg="基础模型删除成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
||||||
|
def add_model_from_plaza(
|
||||||
|
model_base_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""从模型广场添加模型到模型列表"""
|
||||||
|
|
||||||
|
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
|
||||||
|
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{model_id}", response_model=ApiResponse)
|
@router.get("/{model_id}", response_model=ApiResponse)
|
||||||
def get_model_by_id(
|
def get_model_by_id(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
@@ -138,6 +290,73 @@ async def create_model(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/composite", response_model=ApiResponse)
|
||||||
|
async def create_composite_model(
|
||||||
|
model_data: model_schema.CompositeModelCreate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建组合模型
|
||||||
|
|
||||||
|
- 绑定一个或多个现有的 API Key
|
||||||
|
- 所有 API Key 必须来自非组合模型
|
||||||
|
- 所有 API Key 关联的模型类型必须与组合模型类型一致
|
||||||
|
"""
|
||||||
|
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||||
|
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||||
|
|
||||||
|
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||||
|
return success(data=result, msg="组合模型创建成功")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||||
|
async def update_composite_model(
|
||||||
|
model_id: uuid.UUID,
|
||||||
|
model_data: model_schema.CompositeModelCreate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""更新组合模型"""
|
||||||
|
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model_data.type is not None:
|
||||||
|
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||||
|
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||||
|
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||||
|
|
||||||
|
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||||
|
return success(data=result, msg="组合模型更新成功")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/composite/{model_id}", response_model=ApiResponse)
|
||||||
|
def delete_composite_model(
|
||||||
|
model_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""删除组合模型"""
|
||||||
|
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||||
|
api_logger.info(f"组合模型删除成功: model_id={model_id}")
|
||||||
|
return success(msg="组合模型删除成功")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{model_id}", response_model=ApiResponse)
|
@router.put("/{model_id}", response_model=ApiResponse)
|
||||||
def update_model(
|
def update_model(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
@@ -214,6 +433,53 @@ def get_model_api_keys(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/provider/apikeys", response_model=ApiResponse)
|
||||||
|
async def create_model_api_key_by_provider(
|
||||||
|
api_key_data: model_schema.ModelApiKeyCreateByProvider,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据供应商为所有匹配的模型创建API Key
|
||||||
|
"""
|
||||||
|
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 根据tenant_id和provider筛选model_config_id列表
|
||||||
|
model_config_ids = api_key_data.model_config_ids
|
||||||
|
if not model_config_ids:
|
||||||
|
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
|
||||||
|
db=db,
|
||||||
|
tenant_id=current_user.tenant_id,
|
||||||
|
provider=api_key_data.provider
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_config_ids:
|
||||||
|
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
|
# 构造schema并调用service
|
||||||
|
create_data = model_schema.ModelApiKeyCreateByProvider(
|
||||||
|
provider=api_key_data.provider,
|
||||||
|
api_key=api_key_data.api_key,
|
||||||
|
api_base=api_key_data.api_base,
|
||||||
|
description=api_key_data.description,
|
||||||
|
config=api_key_data.config,
|
||||||
|
is_active=api_key_data.is_active,
|
||||||
|
priority=api_key_data.priority,
|
||||||
|
model_config_ids=model_config_ids
|
||||||
|
)
|
||||||
|
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||||
|
|
||||||
|
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||||
|
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||||
|
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
||||||
|
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
||||||
|
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_model_api_key(
|
async def create_model_api_key(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
@@ -228,11 +494,12 @@ async def create_model_api_key(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 设置模型配置ID
|
# 设置模型配置ID
|
||||||
api_key_data.model_config_id = model_id
|
api_key_data.model_config_ids = [model_id]
|
||||||
|
|
||||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||||
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||||
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
|
||||||
|
result = model_schema.ModelApiKey.model_validate(result_orm)
|
||||||
return success(data=result, msg="模型API Key创建成功")
|
return success(data=result, msg="模型API Key创建成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||||
@@ -334,5 +601,3 @@ async def validate_model_config(
|
|||||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -317,9 +317,12 @@ async def chat(
|
|||||||
appid = share.app_id
|
appid = share.app_id
|
||||||
"""获取存储类型和工作空间的ID"""
|
"""获取存储类型和工作空间的ID"""
|
||||||
|
|
||||||
# 直接通过 SQLAlchemy 查询 app
|
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
app = db.query(App).filter(App.id == appid).first()
|
app = db.query(App).filter(
|
||||||
|
App.id == appid,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).first()
|
||||||
if not app:
|
if not app:
|
||||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||||
|
|
||||||
|
|||||||
@@ -235,11 +235,11 @@ async def chat(
|
|||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=new_end_user.id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=web_search,
|
||||||
memory=payload.memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
@@ -268,11 +268,11 @@ async def chat(
|
|||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=new_end_user.id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=web_search,
|
||||||
memory=payload.memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
|||||||
|
|
||||||
Stores memory content for the specified end user using the Memory API Service.
|
Stores memory content for the specified end user using the Memory API Service.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
|||||||
@@ -135,27 +135,27 @@ async def generate_cache_api(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
group_id = request.end_user_id
|
end_user_id = request.end_user_id
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if group_id:
|
if end_user_id:
|
||||||
# 为单个用户生成
|
# 为单个用户生成
|
||||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||||
|
|
||||||
# 生成记忆洞察
|
# 生成记忆洞察
|
||||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
|
||||||
|
|
||||||
# 生成用户摘要
|
# 生成用户摘要
|
||||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
result = {
|
result = {
|
||||||
"end_user_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"insight_success": insight_result["success"],
|
"insight_success": insight_result["success"],
|
||||||
"summary_success": summary_result["success"],
|
"summary_success": summary_result["success"],
|
||||||
"errors": []
|
"errors": []
|
||||||
@@ -175,9 +175,9 @@ async def generate_cache_api(
|
|||||||
|
|
||||||
# 记录结果
|
# 记录结果
|
||||||
if result["insight_success"] and result["summary_success"]:
|
if result["insight_success"] and result["summary_success"]:
|
||||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
|
||||||
else:
|
else:
|
||||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
|
||||||
|
|
||||||
return success(data=result, msg="生成完成")
|
return success(data=result, msg="生成完成")
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ async def create_workflow_config(
|
|||||||
app = db.query(App).filter(
|
app = db.query(App).filter(
|
||||||
App.id == app_id,
|
App.id == app_id,
|
||||||
App.workspace_id == current_user.current_workspace_id,
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
App.is_active == True
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
@@ -214,7 +214,7 @@ async def delete_workflow_config(
|
|||||||
app = db.query(App).filter(
|
app = db.query(App).filter(
|
||||||
App.id == app_id,
|
App.id == app_id,
|
||||||
App.workspace_id == current_user.current_workspace_id,
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
App.is_active == True
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
@@ -259,7 +259,7 @@ async def validate_workflow_config(
|
|||||||
app = db.query(App).filter(
|
app = db.query(App).filter(
|
||||||
App.id == app_id,
|
App.id == app_id,
|
||||||
App.workspace_id == current_user.current_workspace_id,
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
App.is_active == True
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
@@ -329,7 +329,7 @@ async def get_workflow_executions(
|
|||||||
app = db.query(App).filter(
|
app = db.query(App).filter(
|
||||||
App.id == app_id,
|
App.id == app_id,
|
||||||
App.workspace_id == current_user.current_workspace_id,
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
App.is_active == True
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
@@ -389,7 +389,7 @@ async def get_workflow_execution(
|
|||||||
app = db.query(App).filter(
|
app = db.query(App).filter(
|
||||||
App.id == execution.app_id,
|
App.id == execution.app_id,
|
||||||
App.workspace_id == current_user.current_workspace_id,
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
App.is_active == True
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
@@ -440,7 +440,7 @@ async def run_workflow(
|
|||||||
app = db.query(App).filter(
|
app = db.query(App).filter(
|
||||||
App.id == app_id,
|
App.id == app_id,
|
||||||
App.workspace_id == current_user.current_workspace_id,
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
App.is_active == True
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
@@ -578,7 +578,7 @@ async def cancel_workflow_execution(
|
|||||||
app = db.query(App).filter(
|
app = db.query(App).filter(
|
||||||
App.id == execution.app_id,
|
App.id == execution.app_id,
|
||||||
App.workspace_id == current_user.current_workspace_id,
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
App.is_active == True
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ from langchain.agents import create_agent
|
|||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -155,13 +157,13 @@ class LangChainAgent:
|
|||||||
# userid=end_user_end,
|
# userid=end_user_end,
|
||||||
# messages=messages,
|
# messages=messages,
|
||||||
# apply_id=end_user_end,
|
# apply_id=end_user_end,
|
||||||
# group_id=end_user_end,
|
# end_user_id=end_user_end,
|
||||||
# aimessages=aimessages
|
# aimessages=aimessages
|
||||||
# )
|
# )
|
||||||
# store.delete_duplicate_sessions()
|
# store.delete_duplicate_sessions()
|
||||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||||
# return session_id
|
# return session_id
|
||||||
|
|
||||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
# async def term_memory_redis_read(self,end_user_end):
|
# async def term_memory_redis_read(self,end_user_end):
|
||||||
# end_user_end = f"Term_{end_user_end}"
|
# end_user_end = f"Term_{end_user_end}"
|
||||||
@@ -175,11 +177,10 @@ class LangChainAgent:
|
|||||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||||
# retrieved_content.append({query: aimessages})
|
# retrieved_content.append({query: aimessages})
|
||||||
# return messagss_list,retrieved_content
|
# return messagss_list,retrieved_content
|
||||||
|
|
||||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||||
"""
|
"""
|
||||||
写入记忆(支持结构化消息)
|
写入记忆(支持结构化消息)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
storage_type: 存储类型 (neo4j/rag)
|
||||||
end_user_id: 终端用户ID
|
end_user_id: 终端用户ID
|
||||||
@@ -188,7 +189,7 @@ class LangChainAgent:
|
|||||||
user_rag_memory_id: RAG 记忆ID
|
user_rag_memory_id: RAG 记忆ID
|
||||||
actual_end_user_id: 实际用户ID
|
actual_end_user_id: 实际用户ID
|
||||||
actual_config_id: 配置ID
|
actual_config_id: 配置ID
|
||||||
|
|
||||||
逻辑说明:
|
逻辑说明:
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||||
- Neo4j 模式:使用结构化消息列表
|
- Neo4j 模式:使用结构化消息列表
|
||||||
@@ -196,48 +197,54 @@ class LangChainAgent:
|
|||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||||
"""
|
"""
|
||||||
if storage_type == "rag":
|
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
else:
|
|
||||||
# Neo4j 模式:使用结构化消息列表
|
|
||||||
structured_messages = []
|
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
|
||||||
if user_message:
|
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
|
||||||
if ai_message:
|
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
|
||||||
if not structured_messages:
|
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 调用 Celery 任务,传递结构化消息列表
|
|
||||||
# 数据流:
|
|
||||||
# 1. structured_messages 传递给 write_message_task
|
|
||||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
|
||||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
|
||||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
|
||||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
|
||||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
|
||||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
|
||||||
write_id = write_message_task.delay(
|
|
||||||
actual_end_user_id, # group_id: 用户ID
|
|
||||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
|
||||||
actual_config_id, # config_id: 配置ID
|
|
||||||
storage_type, # storage_type: "neo4j"
|
|
||||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
|
||||||
)
|
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
|
||||||
|
|
||||||
|
db = next(get_db())
|
||||||
|
try:
|
||||||
|
actual_config_id=resolve_config_id(actual_config_id, db)
|
||||||
|
|
||||||
|
if storage_type == "rag":
|
||||||
|
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||||
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
|
else:
|
||||||
|
# Neo4j 模式:使用结构化消息列表
|
||||||
|
structured_messages = []
|
||||||
|
|
||||||
|
# 始终添加用户消息(如果不为空)
|
||||||
|
if user_message:
|
||||||
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
|
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||||
|
if ai_message:
|
||||||
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
|
# 如果没有消息,直接返回
|
||||||
|
if not structured_messages:
|
||||||
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 调用 Celery 任务,传递结构化消息列表
|
||||||
|
# 数据流:
|
||||||
|
# 1. structured_messages 传递给 write_message_task
|
||||||
|
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||||
|
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||||
|
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||||
|
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||||
|
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||||
|
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
|
write_id = write_message_task.delay(
|
||||||
|
actual_end_user_id, # end_user_id: 用户ID
|
||||||
|
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
|
actual_config_id, # config_id: 配置ID
|
||||||
|
storage_type, # storage_type: "neo4j"
|
||||||
|
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||||
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
|
|||||||
@@ -9,6 +9,25 @@ load_dotenv()
|
|||||||
|
|
||||||
|
|
||||||
class Settings:
|
class Settings:
|
||||||
|
# ========================================================================
|
||||||
|
# Deployment Mode Configuration
|
||||||
|
# ========================================================================
|
||||||
|
# community: 社区版(开源,功能受限)
|
||||||
|
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||||
|
# enterprise: 企业私有化版(License 控制)
|
||||||
|
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||||
|
|
||||||
|
# License 配置(企业版)
|
||||||
|
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||||
|
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||||
|
|
||||||
|
# 计费服务配置(SaaS 版)
|
||||||
|
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||||
|
|
||||||
|
# 基础 URL(用于 SSO 回调等)
|
||||||
|
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||||
|
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||||
|
|
||||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||||
# API Keys Configuration
|
# API Keys Configuration
|
||||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||||
@@ -72,6 +91,10 @@ class Settings:
|
|||||||
|
|
||||||
# Single Sign-On configuration
|
# Single Sign-On configuration
|
||||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||||
|
|
||||||
|
# SSO 免登配置
|
||||||
|
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||||
|
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||||
|
|
||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
@@ -107,6 +130,7 @@ class Settings:
|
|||||||
|
|
||||||
# Server Configuration
|
# Server Configuration
|
||||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||||
|
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# Internal Configuration (not in .env, used by application code)
|
# Internal Configuration (not in .env, used by application code)
|
||||||
@@ -184,7 +208,7 @@ class Settings:
|
|||||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||||
|
|
||||||
# official environment system version
|
# official environment system version
|
||||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||||
|
|
||||||
# workflow config
|
# workflow config
|
||||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
|||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
"""问题分解节点"""
|
"""问题分解节点"""
|
||||||
# 从状态中获取数据
|
# 从状态中获取数据
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
data = state.get('spit_data', '')['context']
|
data = state.get('spit_data', '')['context']
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
databasets = {}
|
databasets = {}
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||||
|
|||||||
@@ -52,9 +52,9 @@ async def rag_config(state):
|
|||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
async def rag_knowledge(state,question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||||
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
problem_extension=state.get('problem_extension', '')['context']
|
problem_extension=state.get('problem_extension', '')['context']
|
||||||
storage_type=state.get('storage_type', '')
|
storage_type=state.get('storage_type', '')
|
||||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||||
group_id=state.get('group_id', '')
|
end_user_id=state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
original=state.get('data', '')
|
original=state.get('data', '')
|
||||||
problem_list=[]
|
problem_list=[]
|
||||||
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
try:
|
try:
|
||||||
# Prepare search parameters based on storage type
|
# Prepare search parameters based on storage type
|
||||||
search_params = {
|
search_params = {
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"question": question,
|
"question": question,
|
||||||
"return_raw_results": True
|
"return_raw_results": True
|
||||||
}
|
}
|
||||||
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
# 从state中获取group_id
|
# 从state中获取end_user_id
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start=time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
original = state.get('data', '')
|
original = state.get('data', '')
|
||||||
problem_list = []
|
problem_list = []
|
||||||
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(group_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = { "group_id": group_id, "return_raw_results": True }
|
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# 创建异步任务处理单个问题
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
|||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
|
|
||||||
@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
|
|||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
group_id = state.get("group_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||||
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
|
|
||||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
group_id = state.get("group_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
user_id=group_id,
|
user_id=end_user_id,
|
||||||
query=data,
|
query=data,
|
||||||
apply_id=group_id,
|
apply_id=end_user_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
ai_response=aimessages
|
ai_response=aimessages
|
||||||
)
|
)
|
||||||
await SessionService(store).cleanup_duplicates()
|
await SessionService(store).cleanup_duplicates()
|
||||||
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||||
data=state.get("data", '')
|
data=state.get("data", '')
|
||||||
group_id=state.get("group_id", '')
|
end_user_id=state.get("end_user_id", '')
|
||||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
history = await summary_history( state)
|
history = await summary_history( state)
|
||||||
search_params = {
|
search_params = {
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
"return_raw_results": True,
|
"return_raw_results": True,
|
||||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||||
@@ -236,7 +236,7 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
|||||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||||
|
|
||||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||||
'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -276,7 +276,6 @@ async def Summary(state: ReadState)-> ReadState:
|
|||||||
aimessages=await summary_llm(state,history,data,
|
aimessages=await summary_llm(state,history,data,
|
||||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||||
|
|
||||||
|
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -295,9 +294,26 @@ async def Summary(state: ReadState)-> ReadState:
|
|||||||
async def Summary_fails(state: ReadState)-> ReadState:
|
async def Summary_fails(state: ReadState)-> ReadState:
|
||||||
storage_type=state.get("storage_type", '')
|
storage_type=state.get("storage_type", '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||||
|
history = await summary_history(state)
|
||||||
|
query = state.get("data", '')
|
||||||
|
verify = state.get("verify", '')
|
||||||
|
verify_expansion_issue = verify.get("verified_data", '')
|
||||||
|
retrieve_info_str = ''
|
||||||
|
for data in verify_expansion_issue:
|
||||||
|
for key, value in data.items():
|
||||||
|
if key == 'answer_small':
|
||||||
|
for i in value:
|
||||||
|
retrieve_info_str += i + '\n'
|
||||||
|
data = {
|
||||||
|
"query": query,
|
||||||
|
"history": history,
|
||||||
|
"retrieve_info": retrieve_info_str
|
||||||
|
}
|
||||||
|
aimessages = await summary_llm(state, history, data,
|
||||||
|
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||||
result= {
|
result= {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": "没有相关数据",
|
"summary_result": aimessages,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
|||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
|
|||||||
logger.info("=== Verify 节点开始执行 ===")
|
logger.info("=== Verify 节点开始执行 ===")
|
||||||
try:
|
try:
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||||
|
|
||||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||||
|
|
||||||
retrieve = state.get("retrieve", {})
|
retrieve = state.get("retrieve", {})
|
||||||
|
|||||||
@@ -1,23 +1,24 @@
|
|||||||
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
|
||||||
from app.core.memory.agent.utils.write_tools import write
|
from app.core.memory.agent.utils.write_tools import write
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def write_node(state: WriteState) -> WriteState:
|
async def write_node(state: WriteState) -> WriteState:
|
||||||
"""
|
"""
|
||||||
Write data to the database/file system.
|
Write data to the database/file system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: WriteState containing messages, group_id, and memory_config
|
state: WriteState containing messages, end_user_id, and memory_config
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Contains 'write_result' with status and data fields
|
dict: Contains 'write_result' with status and data fields
|
||||||
"""
|
"""
|
||||||
messages = state.get('messages', [])
|
messages = state.get('messages', [])
|
||||||
group_id = state.get('group_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', '')
|
memory_config = state.get('memory_config', '')
|
||||||
|
|
||||||
# Convert LangChain messages to structured format expected by write()
|
# Convert LangChain messages to structured format expected by write()
|
||||||
structured_messages = []
|
structured_messages = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@@ -28,13 +29,11 @@ async def write_node(state: WriteState) -> WriteState:
|
|||||||
"role": role,
|
"role": role,
|
||||||
"content": msg.content # content is now guaranteed to be a string
|
"content": msg.content # content is now guaranteed to be a string
|
||||||
})
|
})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await write(
|
result = await write(
|
||||||
messages=structured_messages,
|
messages=structured_messages,
|
||||||
user_id=group_id,
|
end_user_id=end_user_id,
|
||||||
apply_id=group_id,
|
|
||||||
group_id=group_id,
|
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
)
|
)
|
||||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ async def make_read_graph():
|
|||||||
async def main():
|
async def main():
|
||||||
"""主函数 - 运行工作流"""
|
"""主函数 - 运行工作流"""
|
||||||
message = "昨天有什么好看的电影"
|
message = "昨天有什么好看的电影"
|
||||||
group_id = '88a459f5_text09' # 组ID
|
end_user_id = '88a459f5_text09' # 组ID
|
||||||
storage_type = 'neo4j' # 存储类型
|
storage_type = 'neo4j' # 存储类型
|
||||||
search_switch = '1' # 搜索开关
|
search_switch = '1' # 搜索开关
|
||||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||||
@@ -95,9 +95,9 @@ async def main():
|
|||||||
start=time.time()
|
start=time.time()
|
||||||
try:
|
try:
|
||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
|
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
_intermediate_outputs = []
|
_intermediate_outputs = []
|
||||||
|
|||||||
@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
|
|||||||
class TimeRetrievalInput(BaseModel):
|
class TimeRetrievalInput(BaseModel):
|
||||||
"""时间检索工具的输入模式"""
|
"""时间检索工具的输入模式"""
|
||||||
context: str = Field(description="用户输入的查询内容")
|
context: str = Field(description="用户输入的查询内容")
|
||||||
group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||||
|
|
||||||
def create_time_retrieval_tool(group_id: str):
|
def create_time_retrieval_tool(end_user_id: str):
|
||||||
"""
|
"""
|
||||||
创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_temporal_result_fields(data):
|
def clean_temporal_result_fields(data):
|
||||||
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
|
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||||
显式接收参数:
|
显式接收参数:
|
||||||
- context: 查询上下文内容
|
- context: 查询上下文内容
|
||||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||||
- group_id_param: 组ID(可选,用于覆盖默认组ID)
|
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||||
- clean_output: 是否清理输出中的元数据字段
|
- clean_output: 是否清理输出中的元数据字段
|
||||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||||
"""
|
"""
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# 使用传入的参数或默认值
|
# 使用传入的参数或默认值
|
||||||
actual_group_id = group_id_param or group_id
|
actual_end_user_id = end_user_id_param or end_user_id
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# 基本时间搜索
|
# 基本时间搜索
|
||||||
results = await search_by_temporal(
|
results = await search_by_temporal(
|
||||||
group_id=actual_group_id,
|
end_user_id=actual_end_user_id,
|
||||||
start_date=actual_start_date,
|
start_date=actual_start_date,
|
||||||
end_date=actual_end_date,
|
end_date=actual_end_date,
|
||||||
limit=10
|
limit=10
|
||||||
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
|
|||||||
# 关键词时间搜索
|
# 关键词时间搜索
|
||||||
results = await search_by_keyword_temporal(
|
results = await search_by_keyword_temporal(
|
||||||
query_text=context,
|
query_text=context,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
start_date=actual_start_date,
|
start_date=actual_start_date,
|
||||||
end_date=actual_end_date,
|
end_date=actual_end_date,
|
||||||
limit=15
|
limit=15
|
||||||
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
**search_params: 搜索参数,包含group_id, limit, include等
|
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_result_fields(data):
|
def clean_result_fields(data):
|
||||||
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
context: str,
|
context: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
group_id: str = None,
|
end_user_id: str = None,
|
||||||
rerank_alpha: float = 0.6,
|
rerank_alpha: float = 0.6,
|
||||||
use_forgetting_rerank: bool = False,
|
use_forgetting_rerank: bool = False,
|
||||||
use_llm_rerank: bool = False,
|
use_llm_rerank: bool = False,
|
||||||
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
context: 查询内容
|
context: 查询内容
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: 结果数量限制
|
||||||
group_id: 组ID,用于过滤搜索结果
|
end_user_id: 组ID,用于过滤搜索结果
|
||||||
rerank_alpha: 重排序权重参数
|
rerank_alpha: 重排序权重参数
|
||||||
use_forgetting_rerank: 是否使用遗忘重排序
|
use_forgetting_rerank: 是否使用遗忘重排序
|
||||||
use_llm_rerank: 是否使用LLM重排序
|
use_llm_rerank: 是否使用LLM重排序
|
||||||
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
final_params = {
|
final_params = {
|
||||||
"query_text": context,
|
"query_text": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"group_id": group_id or search_params.get("group_id"),
|
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||||
"limit": limit or search_params.get("limit", 10),
|
"limit": limit or search_params.get("limit", 10),
|
||||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||||
"output_path": None, # 不保存到文件
|
"output_path": None, # 不保存到文件
|
||||||
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
|||||||
context: str,
|
context: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
group_id: str = None,
|
end_user_id: str = None,
|
||||||
clean_output: bool = True
|
clean_output: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
|||||||
context: 查询内容
|
context: 查询内容
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: 结果数量限制
|
||||||
group_id: 组ID,用于过滤搜索结果
|
end_user_id: 组ID,用于过滤搜索结果
|
||||||
clean_output: 是否清理输出中的元数据字段
|
clean_output: 是否清理输出中的元数据字段
|
||||||
"""
|
"""
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
|||||||
"context": context,
|
"context": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"clean_output": clean_output
|
"clean_output": clean_output
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.db import get_db
|
|||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
@@ -26,9 +27,21 @@ async def make_write_graph():
|
|||||||
"""
|
"""
|
||||||
Create a write graph workflow for memory operations.
|
Create a write graph workflow for memory operations.
|
||||||
|
|
||||||
The workflow directly processes messages from the initial state
|
Args:
|
||||||
and saves them to Neo4j storage.
|
user_id: User identifier
|
||||||
|
tools: MCP tools loaded from session
|
||||||
|
apply_id: Application identifier
|
||||||
|
end_user_id: Group identifier
|
||||||
|
memory_config: MemoryConfig object containing all configuration
|
||||||
"""
|
"""
|
||||||
|
# workflow = StateGraph(WriteState)
|
||||||
|
# workflow.add_node("content_input", content_input_write)
|
||||||
|
# workflow.add_node("save_neo4j", write_node)
|
||||||
|
# workflow.add_edge(START, "content_input")
|
||||||
|
# workflow.add_edge("content_input", "save_neo4j")
|
||||||
|
# workflow.add_edge("save_neo4j", END)
|
||||||
|
#
|
||||||
|
# graph = workflow.compile()
|
||||||
workflow = StateGraph(WriteState)
|
workflow = StateGraph(WriteState)
|
||||||
workflow.add_node("save_neo4j", write_node)
|
workflow.add_node("save_neo4j", write_node)
|
||||||
workflow.add_edge(START, "save_neo4j")
|
workflow.add_edge(START, "save_neo4j")
|
||||||
@@ -42,7 +55,7 @@ async def make_write_graph():
|
|||||||
async def main():
|
async def main():
|
||||||
"""主函数 - 运行工作流"""
|
"""主函数 - 运行工作流"""
|
||||||
message = "今天周一"
|
message = "今天周一"
|
||||||
group_id = 'new_2025test1103' # 组ID
|
end_user_id = 'new_2025test1103' # 组ID
|
||||||
|
|
||||||
|
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
@@ -54,9 +67,9 @@ async def main():
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
async with make_write_graph() as graph:
|
async with make_write_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
|
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||||
|
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
async for update_event in graph.astream(
|
async for update_event in graph.astream(
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class ParameterBuilder:
|
|||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
search_switch: str,
|
search_switch: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None
|
user_rag_memory_id: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@@ -44,7 +44,7 @@ class ParameterBuilder:
|
|||||||
tool_call_id: Extracted tool call identifier
|
tool_call_id: Extracted tool call identifier
|
||||||
search_switch: Search routing parameter
|
search_switch: Search routing parameter
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
storage_type: Storage type for the workspace (optional)
|
storage_type: Storage type for the workspace (optional)
|
||||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class ParameterBuilder:
|
|||||||
base_args = {
|
base_args = {
|
||||||
"usermessages": tool_call_id,
|
"usermessages": tool_call_id,
|
||||||
"apply_id": apply_id,
|
"apply_id": apply_id,
|
||||||
"group_id": group_id
|
"end_user_id": end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class SearchService:
|
|||||||
|
|
||||||
async def execute_hybrid_search(
|
async def execute_hybrid_search(
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
question: str,
|
question: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
@@ -105,7 +105,7 @@ class SearchService:
|
|||||||
Execute hybrid search and return clean content.
|
Execute hybrid search and return clean content.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: Group identifier for filtering results
|
end_user_id: Group identifier for filtering results
|
||||||
question: Search query text
|
question: Search query text
|
||||||
limit: Maximum number of results to return (default: 5)
|
limit: Maximum number of results to return (default: 5)
|
||||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||||
@@ -130,7 +130,7 @@ class SearchService:
|
|||||||
answer = await run_hybrid_search(
|
answer = await run_hybrid_search(
|
||||||
query_text=cleaned_query,
|
query_text=cleaned_query,
|
||||||
search_type=search_type,
|
search_type=search_type,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include,
|
include=include,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
@@ -186,7 +186,7 @@ class SearchService:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
# Return empty results on failure
|
# Return empty results on failure
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class SessionService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str
|
end_user_id: str
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve conversation history from Redis.
|
Retrieve conversation history from Redis.
|
||||||
@@ -67,20 +67,20 @@ class SessionService:
|
|||||||
Args:
|
Args:
|
||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of conversation history items with Query and Answer keys
|
List of conversation history items with Query and Answer keys
|
||||||
Returns empty list if no history found or on error
|
Returns empty list if no history found or on error
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||||
|
|
||||||
# Validate history structure
|
# Validate history structure
|
||||||
if not isinstance(history, list):
|
if not isinstance(history, list):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid history format for user {user_id}, "
|
f"Invalid history format for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ class SessionService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to retrieve history for user {user_id}, "
|
f"Failed to retrieve history for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: {e}",
|
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
# Return empty list on error to allow execution to continue
|
# Return empty list on error to allow execution to continue
|
||||||
@@ -100,7 +100,7 @@ class SessionService:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
ai_response: str
|
ai_response: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -110,7 +110,7 @@ class SessionService:
|
|||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
query: User query/message
|
query: User query/message
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
ai_response: AI response/answer
|
ai_response: AI response/answer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -131,7 +131,7 @@ class SessionService:
|
|||||||
userid=user_id,
|
userid=user_id,
|
||||||
messages=query,
|
messages=query,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
aimessages=ai_response
|
aimessages=ai_response
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ class SessionService:
|
|||||||
Duplicates are identified by matching:
|
Duplicates are identified by matching:
|
||||||
- sessionid
|
- sessionid
|
||||||
- user_id (id field)
|
- user_id (id field)
|
||||||
- group_id
|
- end_user_id
|
||||||
- messages
|
- messages
|
||||||
- aimessages
|
- aimessages
|
||||||
|
|
||||||
|
|||||||
@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
|
|||||||
|
|
||||||
async def get_chunked_dialogs(
|
async def get_chunked_dialogs(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
group_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
user_id: str = "user1",
|
|
||||||
apply_id: str = "applyid",
|
|
||||||
messages: list = None,
|
messages: list = None,
|
||||||
ref_id: str = "wyl_20251027",
|
ref_id: str = "wyl_20251027",
|
||||||
config_id: str = None
|
config_id: str = None
|
||||||
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
user_id: User identifier
|
|
||||||
apply_id: Application identifier
|
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference identifier
|
ref_id: Reference identifier
|
||||||
config_id: Configuration ID for processing
|
config_id: Configuration ID for processing
|
||||||
@@ -32,42 +28,40 @@ async def get_chunked_dialogs(
|
|||||||
"""
|
"""
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
||||||
raise ValueError("messages parameter must be a non-empty list")
|
raise ValueError("messages parameter must be a non-empty list")
|
||||||
|
|
||||||
conversation_messages = []
|
conversation_messages = []
|
||||||
|
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||||
|
|
||||||
role = msg['role']
|
role = msg['role']
|
||||||
content = msg['content']
|
content = msg['content']
|
||||||
|
|
||||||
if role not in ['user', 'assistant']:
|
if role not in ['user', 'assistant']:
|
||||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||||
|
|
||||||
if content.strip():
|
if content.strip():
|
||||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||||
|
|
||||||
if not conversation_messages:
|
if not conversation_messages:
|
||||||
raise ValueError("Message list cannot be empty after filtering")
|
raise ValueError("Message list cannot be empty after filtering")
|
||||||
|
|
||||||
conversation_context = ConversationContext(msgs=conversation_messages)
|
conversation_context = ConversationContext(msgs=conversation_messages)
|
||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=conversation_context,
|
context=conversation_context,
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
|
||||||
apply_id=apply_id,
|
|
||||||
config_id=config_id
|
config_id=config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
chunker = DialogueChunker(chunker_strategy)
|
chunker = DialogueChunker(chunker_strategy)
|
||||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||||
dialog_data.chunks = extracted_chunks
|
dialog_data.chunks = extracted_chunks
|
||||||
|
|
||||||
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
||||||
|
|
||||||
return [dialog_data]
|
return [dialog_data]
|
||||||
|
|||||||
@@ -1,24 +1,23 @@
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
from typing import Annotated, TypedDict
|
from typing import Annotated, TypedDict
|
||||||
|
|
||||||
from langchain_core.messages import AnyMessage
|
from langchain_core.messages import AnyMessage
|
||||||
from langgraph.graph import add_messages
|
from langgraph.graph import add_messages
|
||||||
|
|
||||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||||
|
|
||||||
class WriteState(TypedDict):
|
class WriteState(TypedDict):
|
||||||
'''
|
'''
|
||||||
Langgrapg Writing TypedDict
|
Langgrapg Writing TypedDict
|
||||||
'''
|
'''
|
||||||
messages: Annotated[list[AnyMessage], add_messages]
|
messages: Annotated[list[AnyMessage], add_messages]
|
||||||
user_id:str
|
end_user_id: str
|
||||||
apply_id:str
|
|
||||||
group_id:str
|
|
||||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||||
memory_config: object
|
memory_config: object
|
||||||
write_result: dict
|
write_result: dict
|
||||||
data:str
|
data: str
|
||||||
|
|
||||||
class ReadState(TypedDict):
|
class ReadState(TypedDict):
|
||||||
"""
|
"""
|
||||||
@@ -28,7 +27,7 @@ class ReadState(TypedDict):
|
|||||||
messages: 消息列表,支持自动追加
|
messages: 消息列表,支持自动追加
|
||||||
loop_count: 遍历次数
|
loop_count: 遍历次数
|
||||||
search_switch: 搜索类型开关
|
search_switch: 搜索类型开关
|
||||||
group_id: 组标识
|
end_user_id: 组标识
|
||||||
config_id: 配置ID,用于过滤结果
|
config_id: 配置ID,用于过滤结果
|
||||||
data: 从content_input_node传递的内容数据
|
data: 从content_input_node传递的内容数据
|
||||||
spit_data: 从Split_The_Problem传递的分解结果
|
spit_data: 从Split_The_Problem传递的分解结果
|
||||||
@@ -39,7 +38,7 @@ class ReadState(TypedDict):
|
|||||||
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
||||||
loop_count: int
|
loop_count: int
|
||||||
search_switch: str
|
search_switch: str
|
||||||
group_id: str
|
end_user_id: str
|
||||||
config_id: str
|
config_id: str
|
||||||
data: str # 新增字段用于传递内容
|
data: str # 新增字段用于传递内容
|
||||||
spit_data: dict # 新增字段用于传递问题分解结果
|
spit_data: dict # 新增字段用于传递问题分解结果
|
||||||
|
|||||||
@@ -0,0 +1,61 @@
|
|||||||
|
# 角色
|
||||||
|
你是一个智能问答助手,基于检索信息和历史对话回答用户问题。
|
||||||
|
# 任务
|
||||||
|
根据提供的上下文信息回答用户的问题。
|
||||||
|
# 输入信息
|
||||||
|
- 历史对话:{{history}}
|
||||||
|
- 检索信息:{{retrieve_info}}
|
||||||
|
# 用户问题
|
||||||
|
{{query}}
|
||||||
|
# 回答指南
|
||||||
|
## 1. 仔细阅读检索信息
|
||||||
|
- 答案可能直接或间接地出现在检索信息中
|
||||||
|
- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼"
|
||||||
|
- 第三人称描述的偏好、行为通常指用户本人
|
||||||
|
|
||||||
|
## 2. 判断信息相关性
|
||||||
|
**情况A:信息匹配问题**
|
||||||
|
- 直接回答,像自然对话一样
|
||||||
|
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||||
|
|
||||||
|
**情况B:信息部分相关**
|
||||||
|
- 先回答已知部分,再自然地询问更多信息
|
||||||
|
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||||
|
|
||||||
|
**情况C:信息完全不相关**
|
||||||
|
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||||
|
- 使用友好的表达:
|
||||||
|
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||||
|
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||||
|
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||||
|
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||||
|
- 避免僵硬的"信息不足,无法回答"
|
||||||
|
## 3. 回答要求
|
||||||
|
- 像人类对话一样自然流畅
|
||||||
|
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
|
||||||
|
- 不要解释推理过程或引用信息来源
|
||||||
|
- 保持友好、乐于助人的语气
|
||||||
|
- 使用与问题相同的语言回答
|
||||||
|
# 关键示例
|
||||||
|
**示例1 - 直接匹配:**
|
||||||
|
- 检索信息:"小曼会使用Python..."
|
||||||
|
- 问题:"我叫什么"
|
||||||
|
- ✓ 正确:"你叫小曼"
|
||||||
|
- ✗ 错误:"你没有告诉我你的名字"
|
||||||
|
**示例2 - 间接匹配:**
|
||||||
|
- 检索信息:"用户很喜欢吃星巴克的甜品"
|
||||||
|
- 问题:"我喜欢什么"
|
||||||
|
- ✓ 正确:"你很喜欢吃星巴克的甜品"
|
||||||
|
- ✗ 错误:"信息不足"
|
||||||
|
**示例3 - 信息不匹配(推荐做法):**
|
||||||
|
- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦"
|
||||||
|
- 问题:"我吃过哪家面包"
|
||||||
|
- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?"
|
||||||
|
- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?"
|
||||||
|
- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问)
|
||||||
|
- ✗ 错误:"信息不足,无法回答。"(太僵硬)
|
||||||
|
# 重要提醒
|
||||||
|
- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字
|
||||||
|
- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度
|
||||||
|
- 用对话式语言表达"不知道",而非机械模板
|
||||||
|
- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
{# 角色定义 #}
|
||||||
|
你是专业的问题解答专家+引导学者
|
||||||
|
|
||||||
|
{# 输入数据展示 #}
|
||||||
|
{% if data %}
|
||||||
|
## 输入数据
|
||||||
|
上下文信息:
|
||||||
|
{% for item in data.history %}
|
||||||
|
- {{ item }}
|
||||||
|
{% endfor %}
|
||||||
|
检索到的所有信息:
|
||||||
|
{% for item in data.retrieve_info %}
|
||||||
|
- {{ item }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
## User Query
|
||||||
|
{{ query }}
|
||||||
|
|
||||||
|
{# 问题回答标准 #}
|
||||||
|
## 问题回答核心标准
|
||||||
|
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。
|
||||||
|
注意,仔细阅读检索信息,答案可能直接或间接地出现在检索信息中或者历史上下文消息中,同时需要 判断信息相关性
|
||||||
|
**情况A:信息匹配问题**
|
||||||
|
- 直接回答,像自然对话一样
|
||||||
|
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||||
|
|
||||||
|
**情况B:信息部分相关**
|
||||||
|
- 先回答已知部分,再自然地询问更多信息
|
||||||
|
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||||
|
|
||||||
|
**情况C:信息完全不相关**
|
||||||
|
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||||
|
- 使用友好的表达:
|
||||||
|
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||||
|
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||||
|
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||||
|
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||||
|
- 避免僵硬的"信息不足,无法回答"
|
||||||
|
|
||||||
|
{# 重要提醒 #}
|
||||||
|
当检索以及上下文的历史信息都无法回答的时候,可引导对方进行提问/回答,或者进行其他引导
|
||||||
|
当检索或者上下文中出现了,相似的问题,可以委婉,提醒对方,我记得刚刚提过这个问题,但是我自己不记得了,能在描述一次吗~以此为例
|
||||||
@@ -28,7 +28,7 @@ class RedisSessionStore:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
# 修改后的 save_session 方法
|
# 修改后的 save_session 方法
|
||||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
写入一条会话数据,返回 session_id
|
写入一条会话数据,返回 session_id
|
||||||
优化版本:确保写入时间不超过1秒
|
优化版本:确保写入时间不超过1秒
|
||||||
@@ -46,7 +46,7 @@ class RedisSessionStore:
|
|||||||
"id": self.uudi,
|
"id": self.uudi,
|
||||||
"sessionid": userid,
|
"sessionid": userid,
|
||||||
"apply_id": apply_id,
|
"apply_id": apply_id,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"aimessages": aimessages,
|
"aimessages": aimessages,
|
||||||
"starttime": starttime
|
"starttime": starttime
|
||||||
@@ -67,7 +67,7 @@ class RedisSessionStore:
|
|||||||
def save_sessions_batch(self, sessions_data):
|
def save_sessions_batch(self, sessions_data):
|
||||||
"""
|
"""
|
||||||
批量写入多条会话数据,返回 session_id 列表
|
批量写入多条会话数据,返回 session_id 列表
|
||||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
|
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||||
优化版本:批量操作,大幅提升性能
|
优化版本:批量操作,大幅提升性能
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -83,7 +83,7 @@ class RedisSessionStore:
|
|||||||
"id": self.uudi,
|
"id": self.uudi,
|
||||||
"sessionid": session.get('userid'),
|
"sessionid": session.get('userid'),
|
||||||
"apply_id": session.get('apply_id'),
|
"apply_id": session.get('apply_id'),
|
||||||
"group_id": session.get('group_id'),
|
"end_user_id": session.get('end_user_id'),
|
||||||
"messages": session.get('messages'),
|
"messages": session.get('messages'),
|
||||||
"aimessages": session.get('aimessages'),
|
"aimessages": session.get('aimessages'),
|
||||||
"starttime": starttime
|
"starttime": starttime
|
||||||
@@ -108,9 +108,9 @@ class RedisSessionStore:
|
|||||||
data = self.r.hgetall(key)
|
data = self.r.hgetall(key)
|
||||||
return data if data else None
|
return data if data else None
|
||||||
|
|
||||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||||
"""
|
"""
|
||||||
result_items = []
|
result_items = []
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ class RedisSessionStore:
|
|||||||
# 检查三个条件是否都匹配
|
# 检查三个条件是否都匹配
|
||||||
if (data.get('sessionid') == sessionid and
|
if (data.get('sessionid') == sessionid and
|
||||||
data.get('apply_id') == apply_id and
|
data.get('apply_id') == apply_id and
|
||||||
data.get('group_id') == group_id):
|
data.get('end_user_id') == end_user_id):
|
||||||
result_items.append(data)
|
result_items.append(data)
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
@@ -172,7 +172,7 @@ class RedisSessionStore:
|
|||||||
def delete_duplicate_sessions(self):
|
def delete_duplicate_sessions(self):
|
||||||
"""
|
"""
|
||||||
删除重复会话数据,条件:
|
删除重复会话数据,条件:
|
||||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
@@ -202,12 +202,12 @@ class RedisSessionStore:
|
|||||||
# 获取五个字段的值
|
# 获取五个字段的值
|
||||||
sessionid = data.get('sessionid', '')
|
sessionid = data.get('sessionid', '')
|
||||||
user_id = data.get('id', '')
|
user_id = data.get('id', '')
|
||||||
group_id = data.get('group_id', '')
|
end_user_id = data.get('end_user_id', '')
|
||||||
messages = data.get('messages', '')
|
messages = data.get('messages', '')
|
||||||
aimessages = data.get('aimessages', '')
|
aimessages = data.get('aimessages', '')
|
||||||
|
|
||||||
# 用五元组作为唯一标识
|
# 用五元组作为唯一标识
|
||||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||||
|
|
||||||
if identifier in seen:
|
if identifier in seen:
|
||||||
# 重复,标记为待删除
|
# 重复,标记为待删除
|
||||||
@@ -248,9 +248,9 @@ class RedisSessionStore:
|
|||||||
result_items = []
|
result_items = []
|
||||||
return (result_items)
|
return (result_items)
|
||||||
|
|
||||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条
|
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -276,7 +276,7 @@ class RedisSessionStore:
|
|||||||
# 检查是否符合三个条件
|
# 检查是否符合三个条件
|
||||||
|
|
||||||
if (data.get('apply_id') == apply_id and
|
if (data.get('apply_id') == apply_id and
|
||||||
data.get('group_id') == group_id):
|
data.get('end_user_id') == end_user_id):
|
||||||
# 支持模糊匹配 sessionid 或者完全匹配
|
# 支持模糊匹配 sessionid 或者完全匹配
|
||||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||||
matched_items.append({
|
matched_items.append({
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class SessionService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str
|
end_user_id: str
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve conversation history from Redis.
|
Retrieve conversation history from Redis.
|
||||||
@@ -67,20 +67,20 @@ class SessionService:
|
|||||||
Args:
|
Args:
|
||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of conversation history items with Query and Answer keys
|
List of conversation history items with Query and Answer keys
|
||||||
Returns empty list if no history found or on error
|
Returns empty list if no history found or on error
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||||
|
|
||||||
# Validate history structure
|
# Validate history structure
|
||||||
if not isinstance(history, list):
|
if not isinstance(history, list):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid history format for user {user_id}, "
|
f"Invalid history format for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ class SessionService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to retrieve history for user {user_id}, "
|
f"Failed to retrieve history for user {user_id}, "
|
||||||
f"apply {apply_id}, group {group_id}: {e}",
|
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
# Return empty list on error to allow execution to continue
|
# Return empty list on error to allow execution to continue
|
||||||
@@ -100,7 +100,7 @@ class SessionService:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
apply_id: str,
|
apply_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
ai_response: str
|
ai_response: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -110,7 +110,7 @@ class SessionService:
|
|||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
query: User query/message
|
query: User query/message
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
ai_response: AI response/answer
|
ai_response: AI response/answer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -131,7 +131,7 @@ class SessionService:
|
|||||||
userid=user_id,
|
userid=user_id,
|
||||||
messages=query,
|
messages=query,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
aimessages=ai_response
|
aimessages=ai_response
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ class SessionService:
|
|||||||
Duplicates are identified by matching:
|
Duplicates are identified by matching:
|
||||||
- sessionid
|
- sessionid
|
||||||
- user_id (id field)
|
- user_id (id field)
|
||||||
- group_id
|
- end_user_id
|
||||||
- messages
|
- messages
|
||||||
- aimessages
|
- aimessages
|
||||||
|
|
||||||
|
|||||||
@@ -29,20 +29,18 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
user_id: str,
|
end_user_id: str,
|
||||||
apply_id: str,
|
|
||||||
group_id: str,
|
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
messages: list,
|
messages: list,
|
||||||
ref_id: str = "wyl20251027",
|
ref_id: str = "wyl20251027",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the complete knowledge extraction pipeline.
|
Execute the complete knowledge extraction pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User identifier
|
user_id: User identifier
|
||||||
apply_id: Application identifier
|
apply_id: Application identifier
|
||||||
group_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference ID, defaults to "wyl20251027"
|
ref_id: Reference ID, defaults to "wyl20251027"
|
||||||
@@ -51,14 +49,14 @@ async def write(
|
|||||||
embedding_model_id = str(memory_config.embedding_model_id)
|
embedding_model_id = str(memory_config.embedding_model_id)
|
||||||
chunker_strategy = memory_config.chunker_strategy
|
chunker_strategy = memory_config.chunker_strategy
|
||||||
config_id = str(memory_config.config_id)
|
config_id = str(memory_config.config_id)
|
||||||
|
|
||||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||||
logger.info(f"Group ID: {group_id}")
|
logger.info(f"end_user_id ID: {end_user_id}")
|
||||||
|
|
||||||
# Construct clients from memory_config using factory pattern with db session
|
# Construct clients from memory_config using factory pattern with db session
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
@@ -83,9 +81,7 @@ async def write(
|
|||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
chunked_dialogs = await get_chunked_dialogs(
|
chunked_dialogs = await get_chunked_dialogs(
|
||||||
chunker_strategy=chunker_strategy,
|
chunker_strategy=chunker_strategy,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
|
||||||
apply_id=apply_id,
|
|
||||||
messages=messages,
|
messages=messages,
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
|
|||||||
@@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def get_default_docs_path() -> str:
|
def get_default_docs_path() -> str:
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
from pathlib import Path
|
||||||
|
project_root = str(Path(__file__).resolve().parents[2])
|
||||||
return os.path.join(project_root, "src", "analytics", "API接口.md")
|
return os.path.join(project_root, "src", "analytics", "API接口.md")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
|
|||||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||||
|
|
||||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tags: 原始标签列表
|
tags: 原始标签列表
|
||||||
group_id: 用户组ID,用于获取配置
|
end_user_id: 用户组ID,用于获取配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
筛选后的标签列表
|
筛选后的标签列表
|
||||||
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
|||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
|
||||||
if not config_id:
|
if not config_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No memory_config_id found for group_id: {group_id}. "
|
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||||
"Please ensure the user has a valid memory configuration."
|
"Please ensure the user has a valid memory configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
|||||||
|
|
||||||
async def get_raw_tags_from_db(
|
async def get_raw_tags_from_db(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
limit: int,
|
limit: int,
|
||||||
by_user: bool = False
|
by_user: bool = False
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: Neo4j连接器实例
|
connector: Neo4j连接器实例
|
||||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||||
limit: 返回的标签数量限制
|
limit: 返回的标签数量限制
|
||||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||||
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
|
|||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"MATCH (e:ExtractedEntity) "
|
"MATCH (e:ExtractedEntity) "
|
||||||
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||||
"RETURN e.name AS name, count(e) AS frequency "
|
"RETURN e.name AS name, count(e) AS frequency "
|
||||||
"ORDER BY frequency DESC "
|
"ORDER BY frequency DESC "
|
||||||
"LIMIT $limit"
|
"LIMIT $limit"
|
||||||
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
|
|||||||
# 使用项目的Neo4jConnector执行查询
|
# 使用项目的Neo4jConnector执行查询
|
||||||
results = await connector.execute_query(
|
results = await connector.execute_query(
|
||||||
query,
|
query,
|
||||||
id=group_id,
|
id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
names_to_exclude=names_to_exclude
|
names_to_exclude=names_to_exclude
|
||||||
)
|
)
|
||||||
|
|
||||||
return [(record["name"], record["frequency"]) for record in results]
|
return [(record["name"], record["frequency"]) for record in results]
|
||||||
|
|
||||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||||
limit: 返回的标签数量限制
|
limit: 返回的标签数量限制
|
||||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 如果group_id未提供或为空
|
ValueError: 如果end_user_id未提供或为空
|
||||||
"""
|
"""
|
||||||
# 验证group_id必须提供且不为空
|
# 验证end_user_id必须提供且不为空
|
||||||
if not group_id or not group_id.strip():
|
if not end_user_id or not end_user_id.strip():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"group_id is required. Please provide a valid group_id or user_id."
|
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用项目的Neo4jConnector
|
# 使用项目的Neo4jConnector
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
# 1. 从数据库获取原始排名靠前的标签
|
# 1. 从数据库获取原始排名靠前的标签
|
||||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||||
if not raw_tags_with_freq:
|
if not raw_tags_with_freq:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||||
|
|
||||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
|
||||||
|
|
||||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||||
final_tags = []
|
final_tags = []
|
||||||
|
|||||||
@@ -75,8 +75,8 @@ class MemoryDataSource:
|
|||||||
start_date = time_range.start_date if time_range else None
|
start_date = time_range.start_date if time_range else None
|
||||||
end_date = time_range.end_date if time_range else None
|
end_date = time_range.end_date if time_range else None
|
||||||
|
|
||||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
|
||||||
group_id=user_id,
|
end_user_id=user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date
|
end_date=end_date
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ import os
|
|||||||
import re
|
import re
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT
|
from app.core.memory.utils.config.definitions import PROJECT_ROOT
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback: derive project root from this file location
|
# Fallback: derive project root from this file location
|
||||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
# 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py
|
||||||
|
# 需要向上 5 级到达 api/ 目录
|
||||||
|
PROJECT_ROOT = str(Path(__file__).resolve().parents[4])
|
||||||
|
|
||||||
|
|
||||||
def _get_latest_prompt_log_path() -> str | None:
|
def _get_latest_prompt_log_path() -> str | None:
|
||||||
@@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict:
|
|||||||
triplet_relations_count = 0
|
triplet_relations_count = 0
|
||||||
temporal_count = 0
|
temporal_count = 0
|
||||||
|
|
||||||
# Patterns
|
# 正则表达式模式 - 匹配当前日志格式
|
||||||
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
|
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
|
||||||
pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)")
|
pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=")
|
||||||
pat_triplet_done = re.compile(
|
pat_triplet_completed = re.compile(
|
||||||
r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)"
|
r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)"
|
||||||
)
|
)
|
||||||
pat_temporal_done = re.compile(
|
pat_temporal_completed = re.compile(
|
||||||
r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)"
|
r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)"
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
|
with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
# Chunk prompts count (each chunk triggers one statement-extraction prompt render)
|
# 文本块数量(每个块触发一次陈述提取提示)
|
||||||
if pat_chunk_render.search(line):
|
if pat_chunk_render.search(line):
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m1 = pat_triplet_start.search(line)
|
# 陈述数量(每个 Triplet Started 代表一个陈述被处理)
|
||||||
if m1:
|
if pat_triplet_started.search(line):
|
||||||
|
statements_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y
|
||||||
|
m_triplet = pat_triplet_completed.search(line)
|
||||||
|
if m_triplet:
|
||||||
try:
|
try:
|
||||||
statements_count += int(m1.group(1))
|
triplet_relations_count += int(m_triplet.group(1))
|
||||||
|
triplet_entities_count += int(m_triplet.group(2))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m2 = pat_triplet_done.search(line)
|
# 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X
|
||||||
if m2:
|
m_temporal = pat_temporal_completed.search(line)
|
||||||
|
if m_temporal:
|
||||||
try:
|
try:
|
||||||
triplet_relations_count += int(m2.group(1))
|
temporal_count += int(m_temporal.group(1))
|
||||||
triplet_entities_count += int(m2.group(2))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
continue
|
|
||||||
|
|
||||||
m3 = pat_temporal_done.search(line)
|
|
||||||
if m3:
|
|
||||||
try:
|
|
||||||
temporal_count += int(m3.group(1))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
continue
|
continue
|
||||||
@@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def get_recent_activity_stats() -> Tuple[dict, str]:
|
def get_recent_activity_stats() -> Tuple[dict, str]:
|
||||||
"""Get aggregated stats from all prompt logs in logs/.
|
"""Get stats from the latest prompt log file only.
|
||||||
|
|
||||||
Returns (stats_dict, message).
|
Returns (stats_dict, message).
|
||||||
"""
|
"""
|
||||||
all_logs = _get_all_prompt_logs()
|
# 获取最新的日志文件
|
||||||
# Fallback to recursive search if none found in logs/
|
latest_log = _get_latest_prompt_log_path()
|
||||||
if not all_logs:
|
|
||||||
|
# 如果没有找到,尝试递归搜索
|
||||||
|
if not latest_log:
|
||||||
all_logs = _get_any_logs_recursive()
|
all_logs = _get_any_logs_recursive()
|
||||||
if not all_logs:
|
if all_logs:
|
||||||
|
latest_log = all_logs[-1] # 取最新的
|
||||||
|
|
||||||
|
if not latest_log:
|
||||||
return (
|
return (
|
||||||
{
|
{
|
||||||
"chunk_count": 0,
|
"chunk_count": 0,
|
||||||
@@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]:
|
|||||||
"未找到日志文件,请确认已运行过提取流程。",
|
"未找到日志文件,请确认已运行过提取流程。",
|
||||||
)
|
)
|
||||||
|
|
||||||
agg = {
|
# 只解析最新的日志文件
|
||||||
"chunk_count": 0,
|
stats = parse_stats_from_log(latest_log)
|
||||||
"statements_count": 0,
|
|
||||||
"triplet_entities_count": 0,
|
# 添加日志文件路径信息
|
||||||
"triplet_relations_count": 0,
|
stats["log_path"] = f"最新:{latest_log}"
|
||||||
"temporal_count": 0,
|
|
||||||
}
|
return stats, "成功读取最近一次记忆活动统计。"
|
||||||
for path in all_logs:
|
|
||||||
s = parse_stats_from_log(path)
|
|
||||||
agg["chunk_count"] += s.get("chunk_count", 0)
|
|
||||||
agg["statements_count"] += s.get("statements_count", 0)
|
|
||||||
agg["triplet_entities_count"] += s.get("triplet_entities_count", 0)
|
|
||||||
agg["triplet_relations_count"] += s.get("triplet_relations_count", 0)
|
|
||||||
agg["temporal_count"] += s.get("temporal_count", 0)
|
|
||||||
|
|
||||||
# Attach a summary of files combined
|
|
||||||
agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}"
|
|
||||||
return agg, "成功汇总 logs 目录中所有提示日志。"
|
|
||||||
|
|
||||||
|
|
||||||
def _format_summary(stats: dict) -> str:
|
def _format_summary(stats: dict) -> str:
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Evaluation package with dataset-specific pipelines and a unified runner."""
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
⏬数据集下载地址:
|
|
||||||
Locomo10.json:https://github.com/snap-research/locomo/tree/main/data
|
|
||||||
LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
|
||||||
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
|
||||||
上方数据集下载好后全部放入app/core/memory/data文件夹中
|
|
||||||
|
|
||||||
全流程基准测试运行:
|
|
||||||
locomo:
|
|
||||||
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
|
|
||||||
LongMemEval:
|
|
||||||
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
|
|
||||||
memsciqa:
|
|
||||||
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
|
|
||||||
|
|
||||||
单独检索评估运行命令:
|
|
||||||
python -m app.core.memory.evaluation.locomo.locomo_test
|
|
||||||
python -m app.core.memory.evaluation.longmemeval.test_eval
|
|
||||||
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
|
|
||||||
需要先在项目中修改需要检测评估的group_id。
|
|
||||||
|
|
||||||
参数及解释:
|
|
||||||
● --dataset longmemeval - 指定数据集
|
|
||||||
● --sample-size 10 - 评估10个样本
|
|
||||||
● --start-index 0 - 从第0个样本开始
|
|
||||||
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
|
|
||||||
● --search-limit 8 - 检索限制8条
|
|
||||||
● --context-char-budget 4000 - 上下文字符预算4000
|
|
||||||
● --search-type hybrid - 使用混合检索
|
|
||||||
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
|
|
||||||
● --reset-group - 运行前清空组数据
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
import math
|
|
||||||
import re
|
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize(text: str) -> List[str]:
|
|
||||||
"""Lowercase, strip punctuation, and split into tokens."""
|
|
||||||
text = text.lower().strip()
|
|
||||||
# Python's re doesn't support \p classes; use a simple non-word filter
|
|
||||||
text = re.sub(r"[^\w\s]", " ", text)
|
|
||||||
tokens = [t for t in text.split() if t]
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
def exact_match(pred: str, ref: str) -> float:
|
|
||||||
return float(_normalize(pred) == _normalize(ref))
|
|
||||||
|
|
||||||
|
|
||||||
def jaccard(pred: str, ref: str) -> float:
|
|
||||||
p = set(_normalize(pred))
|
|
||||||
r = set(_normalize(ref))
|
|
||||||
if not p and not r:
|
|
||||||
return 1.0
|
|
||||||
if not p or not r:
|
|
||||||
return 0.0
|
|
||||||
return len(p & r) / len(p | r)
|
|
||||||
|
|
||||||
|
|
||||||
def f1_score(pred: str, ref: str) -> float:
|
|
||||||
p_tokens = _normalize(pred)
|
|
||||||
r_tokens = _normalize(ref)
|
|
||||||
if not p_tokens and not r_tokens:
|
|
||||||
return 1.0
|
|
||||||
if not p_tokens or not r_tokens:
|
|
||||||
return 0.0
|
|
||||||
p_set = set(p_tokens)
|
|
||||||
r_set = set(r_tokens)
|
|
||||||
tp = len(p_set & r_set)
|
|
||||||
precision = tp / len(p_set) if p_set else 0.0
|
|
||||||
recall = tp / len(r_set) if r_set else 0.0
|
|
||||||
if precision + recall == 0:
|
|
||||||
return 0.0
|
|
||||||
return 2 * precision * recall / (precision + recall)
|
|
||||||
|
|
||||||
|
|
||||||
def bleu1(pred: str, ref: str) -> float:
|
|
||||||
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
|
|
||||||
p_tokens = _normalize(pred)
|
|
||||||
r_tokens = _normalize(ref)
|
|
||||||
if not p_tokens:
|
|
||||||
return 0.0
|
|
||||||
# Clipped count
|
|
||||||
r_counts: Dict[str, int] = {}
|
|
||||||
for t in r_tokens:
|
|
||||||
r_counts[t] = r_counts.get(t, 0) + 1
|
|
||||||
clipped = 0
|
|
||||||
p_counts: Dict[str, int] = {}
|
|
||||||
for t in p_tokens:
|
|
||||||
p_counts[t] = p_counts.get(t, 0) + 1
|
|
||||||
for t, c in p_counts.items():
|
|
||||||
clipped += min(c, r_counts.get(t, 0))
|
|
||||||
precision = clipped / max(len(p_tokens), 1)
|
|
||||||
# Brevity penalty
|
|
||||||
ref_len = len(r_tokens)
|
|
||||||
pred_len = len(p_tokens)
|
|
||||||
if pred_len > ref_len or pred_len == 0:
|
|
||||||
bp = 1.0
|
|
||||||
else:
|
|
||||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
|
||||||
return bp * precision
|
|
||||||
|
|
||||||
|
|
||||||
def percentile(values: List[float], p: float) -> float:
|
|
||||||
if not values:
|
|
||||||
return 0.0
|
|
||||||
vals = sorted(values)
|
|
||||||
k = (len(vals) - 1) * p
|
|
||||||
f = math.floor(k)
|
|
||||||
c = math.ceil(k)
|
|
||||||
if f == c:
|
|
||||||
return vals[int(k)]
|
|
||||||
return vals[f] + (k - f) * (vals[c] - vals[f])
|
|
||||||
|
|
||||||
|
|
||||||
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
|
|
||||||
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
|
|
||||||
if not latencies_ms:
|
|
||||||
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
|
|
||||||
p25 = percentile(latencies_ms, 0.25)
|
|
||||||
p50 = percentile(latencies_ms, 0.50)
|
|
||||||
p75 = percentile(latencies_ms, 0.75)
|
|
||||||
p95 = percentile(latencies_ms, 0.95)
|
|
||||||
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
|
|
||||||
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
|
|
||||||
|
|
||||||
|
|
||||||
def avg_context_tokens(contexts: List[str]) -> float:
|
|
||||||
if not contexts:
|
|
||||||
return 0.0
|
|
||||||
return sum(len(_normalize(c)) for c in contexts) / len(contexts)
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
"""
|
|
||||||
Dialogue search queries for evaluation purposes.
|
|
||||||
This file contains Cypher queries for searching dialogues, entities, and chunks.
|
|
||||||
Placed in evaluation directory to avoid circular imports with src modules.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Entity search queries
|
|
||||||
SEARCH_ENTITIES_BY_NAME = """
|
|
||||||
MATCH (e:Entity)
|
|
||||||
WHERE e.name = $name
|
|
||||||
RETURN e
|
|
||||||
"""
|
|
||||||
|
|
||||||
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
|
|
||||||
MATCH (e:Entity)
|
|
||||||
WHERE e.name CONTAINS $name
|
|
||||||
RETURN e
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Chunk search queries
|
|
||||||
SEARCH_CHUNKS_BY_CONTENT = """
|
|
||||||
MATCH (c:Chunk)
|
|
||||||
WHERE c.content CONTAINS $content
|
|
||||||
RETURN c
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Dialogue search queries
|
|
||||||
SEARCH_DIALOGUE_BY_DIALOG_ID = """
|
|
||||||
MATCH (d:Dialogue)
|
|
||||||
WHERE d.dialog_id = $dialog_id
|
|
||||||
RETURN d
|
|
||||||
"""
|
|
||||||
|
|
||||||
SEARCH_DIALOGUES_BY_CONTENT = """
|
|
||||||
MATCH (d:Dialogue)
|
|
||||||
WHERE d.content CONTAINS $q
|
|
||||||
RETURN d
|
|
||||||
"""
|
|
||||||
|
|
||||||
DIALOGUE_EMBEDDING_SEARCH = """
|
|
||||||
WITH $embedding AS q
|
|
||||||
MATCH (d:Dialogue)
|
|
||||||
WHERE d.dialog_embedding IS NOT NULL
|
|
||||||
AND ($group_id IS NULL OR d.group_id = $group_id)
|
|
||||||
WITH d, q, d.dialog_embedding AS v
|
|
||||||
WITH d,
|
|
||||||
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
|
|
||||||
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
|
|
||||||
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
|
|
||||||
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
|
|
||||||
WHERE score > $threshold
|
|
||||||
RETURN d.id AS dialog_id,
|
|
||||||
d.group_id AS group_id,
|
|
||||||
d.content AS content,
|
|
||||||
d.created_at AS created_at,
|
|
||||||
d.expired_at AS expired_at,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
@@ -1,341 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
|
||||||
from app.core.memory.models.message_models import (
|
|
||||||
ConversationContext,
|
|
||||||
ConversationMessage,
|
|
||||||
DialogData,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用新的模块化架构
|
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
|
||||||
ExtractionOrchestrator,
|
|
||||||
)
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
|
||||||
DialogueChunker,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.config.definitions import (
|
|
||||||
SELECTED_CHUNKER_STRATEGY,
|
|
||||||
SELECTED_EMBEDDING_ID,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.db import get_db_context
|
|
||||||
|
|
||||||
# Import from database module
|
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
|
|
||||||
# Cypher queries for evaluation
|
|
||||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
|
||||||
|
|
||||||
|
|
||||||
async def ingest_contexts_via_full_pipeline(
|
|
||||||
contexts: List[str],
|
|
||||||
group_id: str,
|
|
||||||
chunker_strategy: str | None = None,
|
|
||||||
embedding_name: str | None = None,
|
|
||||||
save_chunk_output: bool = False,
|
|
||||||
save_chunk_output_path: str | None = None,
|
|
||||||
) -> bool:
|
|
||||||
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
|
|
||||||
|
|
||||||
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
|
|
||||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
|
||||||
Args:
|
|
||||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
|
||||||
group_id: Group ID to assign to generated DialogData and graph nodes.
|
|
||||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
|
||||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
|
||||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
|
||||||
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
|
|
||||||
Returns:
|
|
||||||
True if data saved successfully, False otherwise.
|
|
||||||
"""
|
|
||||||
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
|
|
||||||
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
|
|
||||||
|
|
||||||
# Initialize llm client with graceful fallback
|
|
||||||
llm_client = None
|
|
||||||
llm_available = True
|
|
||||||
try:
|
|
||||||
from app.core.memory.utils.config import definitions as config_defs
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
|
||||||
llm_available = False
|
|
||||||
|
|
||||||
# Step A: Build DialogData list from contexts with robust parsing
|
|
||||||
chunker = DialogueChunker(chunker_strategy)
|
|
||||||
dialog_data_list: List[DialogData] = []
|
|
||||||
|
|
||||||
for idx, ctx in enumerate(contexts):
|
|
||||||
messages: List[ConversationMessage] = []
|
|
||||||
|
|
||||||
# Improved parsing: capture multi-line message blocks, normalize roles
|
|
||||||
pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)"
|
|
||||||
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
|
|
||||||
|
|
||||||
if matches:
|
|
||||||
for m in matches:
|
|
||||||
raw_role = m.group(1).strip()
|
|
||||||
content = m.group(2).strip()
|
|
||||||
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
|
|
||||||
messages.append(ConversationMessage(role=norm_role, msg=content))
|
|
||||||
else:
|
|
||||||
# Fallback: line-by-line parsing
|
|
||||||
for raw in ctx.split("\n"):
|
|
||||||
line = raw.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line)
|
|
||||||
if m:
|
|
||||||
role = m.group(1).strip()
|
|
||||||
msg = m.group(2).strip()
|
|
||||||
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
|
|
||||||
messages.append(ConversationMessage(role=norm_role, msg=msg))
|
|
||||||
else:
|
|
||||||
# Final fallback: treat as user message
|
|
||||||
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
|
|
||||||
messages.append(ConversationMessage(role=default_role, msg=line))
|
|
||||||
|
|
||||||
context_model = ConversationContext(msgs=messages)
|
|
||||||
dialog = DialogData(
|
|
||||||
context=context_model,
|
|
||||||
ref_id=f"pipeline_item_{idx}",
|
|
||||||
group_id=group_id,
|
|
||||||
user_id="default_user",
|
|
||||||
apply_id="default_application",
|
|
||||||
)
|
|
||||||
# Generate chunks
|
|
||||||
dialog.chunks = await chunker.process_dialogue(dialog)
|
|
||||||
dialog_data_list.append(dialog)
|
|
||||||
|
|
||||||
if not dialog_data_list:
|
|
||||||
print("No dialogs to process for ingestion.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Optionally save chunking outputs for debugging
|
|
||||||
if save_chunk_output:
|
|
||||||
try:
|
|
||||||
def _serialize_datetime(obj):
|
|
||||||
if isinstance(obj, datetime):
|
|
||||||
return obj.isoformat()
|
|
||||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
settings.ensure_memory_output_dir()
|
|
||||||
default_path = settings.get_memory_output_path("chunker_test_output.txt")
|
|
||||||
out_path = save_chunk_output_path or default_path
|
|
||||||
|
|
||||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
|
||||||
with open(out_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
|
|
||||||
print(f"Saved chunking results to: {out_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to save chunking results: {e}")
|
|
||||||
|
|
||||||
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
|
|
||||||
if not llm_available:
|
|
||||||
print("[Ingestion] Skipping extraction pipeline (no LLM).")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 初始化 embedder 客户端
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
try:
|
|
||||||
with get_db_context() as db:
|
|
||||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
|
||||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
|
||||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[Ingestion] Failed to initialize embedder client: {e}")
|
|
||||||
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
|
|
||||||
return False
|
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
|
|
||||||
# 初始化并运行 ExtractionOrchestrator
|
|
||||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
|
||||||
config = get_pipeline_config()
|
|
||||||
|
|
||||||
orchestrator = ExtractionOrchestrator(
|
|
||||||
llm_client=llm_client,
|
|
||||||
embedder_client=embedder_client,
|
|
||||||
connector=connector,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建一个包装的 orchestrator 来修复时间提取器的输出
|
|
||||||
# 保存原始的 _assign_extracted_data 方法
|
|
||||||
original_assign = orchestrator._assign_extracted_data
|
|
||||||
|
|
||||||
def clean_temporal_value(value):
|
|
||||||
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
if isinstance(value, str):
|
|
||||||
# 处理字符串形式的 'null', 'None', 空字符串等
|
|
||||||
if value.lower() in ('null', 'none', '') or value.strip() == '':
|
|
||||||
return None
|
|
||||||
return value
|
|
||||||
|
|
||||||
async def patched_assign_extracted_data(*args, **kwargs):
|
|
||||||
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
|
|
||||||
result = await original_assign(*args, **kwargs)
|
|
||||||
|
|
||||||
# 清理返回的 dialog_data_list 中的 temporal_validity
|
|
||||||
for dialog in result:
|
|
||||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
|
||||||
for chunk in dialog.chunks:
|
|
||||||
if hasattr(chunk, 'statements') and chunk.statements:
|
|
||||||
for statement in chunk.statements:
|
|
||||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
|
||||||
tv = statement.temporal_validity
|
|
||||||
# 清理 valid_at 和 invalid_at
|
|
||||||
if hasattr(tv, 'valid_at'):
|
|
||||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
|
||||||
if hasattr(tv, 'invalid_at'):
|
|
||||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# 替换方法
|
|
||||||
orchestrator._assign_extracted_data = patched_assign_extracted_data
|
|
||||||
|
|
||||||
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
|
|
||||||
original_create = orchestrator._create_nodes_and_edges
|
|
||||||
|
|
||||||
async def patched_create_nodes_and_edges(dialog_data_list_arg):
|
|
||||||
"""包装方法:在创建节点前再次清理 temporal_validity"""
|
|
||||||
# 最后一次清理,确保万无一失
|
|
||||||
for dialog in dialog_data_list_arg:
|
|
||||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
|
||||||
for chunk in dialog.chunks:
|
|
||||||
if hasattr(chunk, 'statements') and chunk.statements:
|
|
||||||
for statement in chunk.statements:
|
|
||||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
|
||||||
tv = statement.temporal_validity
|
|
||||||
if hasattr(tv, 'valid_at'):
|
|
||||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
|
||||||
if hasattr(tv, 'invalid_at'):
|
|
||||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
|
||||||
|
|
||||||
return await original_create(dialog_data_list_arg)
|
|
||||||
|
|
||||||
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
|
|
||||||
|
|
||||||
# 运行完整的提取流水线
|
|
||||||
# orchestrator.run 返回 7 个元素的元组
|
|
||||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
|
||||||
(
|
|
||||||
dialogue_nodes,
|
|
||||||
chunk_nodes,
|
|
||||||
statement_nodes,
|
|
||||||
entity_nodes,
|
|
||||||
statement_chunk_edges,
|
|
||||||
statement_entity_edges,
|
|
||||||
entity_entity_edges,
|
|
||||||
) = result
|
|
||||||
|
|
||||||
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
|
|
||||||
|
|
||||||
# Step G: 生成记忆摘要
|
|
||||||
print("[Ingestion] Generating memory summaries...")
|
|
||||||
try:
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
|
||||||
memory_summary_generation,
|
|
||||||
)
|
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
|
||||||
|
|
||||||
summaries = await memory_summary_generation(
|
|
||||||
chunked_dialogs=dialog_data_list,
|
|
||||||
llm_client=llm_client,
|
|
||||||
embedder_client=embedder_client
|
|
||||||
)
|
|
||||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
|
|
||||||
summaries = []
|
|
||||||
|
|
||||||
# Step H: Save to Neo4j
|
|
||||||
try:
|
|
||||||
success = await save_dialog_and_statements_to_neo4j(
|
|
||||||
dialogue_nodes=dialogue_nodes,
|
|
||||||
chunk_nodes=chunk_nodes,
|
|
||||||
statement_nodes=statement_nodes,
|
|
||||||
entity_nodes=entity_nodes,
|
|
||||||
entity_edges=entity_entity_edges,
|
|
||||||
statement_chunk_edges=statement_chunk_edges,
|
|
||||||
statement_entity_edges=statement_entity_edges,
|
|
||||||
connector=connector
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save memory summaries separately
|
|
||||||
if summaries:
|
|
||||||
try:
|
|
||||||
await add_memory_summary_nodes(summaries, connector)
|
|
||||||
await add_memory_summary_statement_edges(summaries, connector)
|
|
||||||
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Failed to save summary nodes: {e}")
|
|
||||||
|
|
||||||
await connector.close()
|
|
||||||
if success:
|
|
||||||
print("Successfully saved extracted data to Neo4j!")
|
|
||||||
else:
|
|
||||||
print("Failed to save data to Neo4j")
|
|
||||||
return success
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to save data to Neo4j: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_context_processing(args):
|
|
||||||
"""Handle context-based processing from command line arguments."""
|
|
||||||
contexts = []
|
|
||||||
|
|
||||||
if args.contexts:
|
|
||||||
contexts.extend(args.contexts)
|
|
||||||
|
|
||||||
if args.context_file:
|
|
||||||
try:
|
|
||||||
with open(args.context_file, 'r', encoding='utf-8') as f:
|
|
||||||
contexts.extend(line.strip() for line in f if line.strip())
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading context file: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not contexts:
|
|
||||||
print("No contexts provided for processing.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await main_from_contexts(contexts, args.context_group_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def main_from_contexts(contexts: List[str], group_id: str):
|
|
||||||
"""Run the pipeline from provided dialogue contexts instead of test data."""
|
|
||||||
print("=== Running pipeline from provided contexts ===")
|
|
||||||
|
|
||||||
success = await ingest_contexts_via_full_pipeline(
|
|
||||||
contexts=contexts,
|
|
||||||
group_id=group_id,
|
|
||||||
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
|
|
||||||
embedding_name=SELECTED_EMBEDDING_ID,
|
|
||||||
save_chunk_output=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
print("Successfully processed and saved contexts to Neo4j!")
|
|
||||||
else:
|
|
||||||
print("Failed to process contexts.")
|
|
||||||
|
|
||||||
return success
|
|
||||||
@@ -1,575 +0,0 @@
|
|||||||
"""
|
|
||||||
LoCoMo Benchmark Script
|
|
||||||
|
|
||||||
This module provides the main entry point for running LoCoMo benchmark evaluations.
|
|
||||||
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
|
|
||||||
in a clean, maintainable way.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python locomo_benchmark.py --sample_size 20 --search_type hybrid
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
try:
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
except ImportError:
|
|
||||||
def load_dotenv():
|
|
||||||
pass
|
|
||||||
|
|
||||||
from app.core.memory.evaluation.common.metrics import (
|
|
||||||
avg_context_tokens,
|
|
||||||
bleu1,
|
|
||||||
f1_score,
|
|
||||||
jaccard,
|
|
||||||
latency_stats,
|
|
||||||
)
|
|
||||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
|
||||||
get_category_name,
|
|
||||||
locomo_f1_score,
|
|
||||||
locomo_multi_f1,
|
|
||||||
)
|
|
||||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
|
||||||
extract_conversations,
|
|
||||||
ingest_conversations_if_needed,
|
|
||||||
load_locomo_data,
|
|
||||||
resolve_temporal_references,
|
|
||||||
retrieve_relevant_information,
|
|
||||||
select_and_format_information,
|
|
||||||
)
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|
||||||
from app.core.memory.utils.definitions import (
|
|
||||||
PROJECT_ROOT,
|
|
||||||
SELECTED_EMBEDDING_ID,
|
|
||||||
SELECTED_GROUP_ID,
|
|
||||||
SELECTED_LLM_ID,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
|
|
||||||
async def run_locomo_benchmark(
|
|
||||||
sample_size: int = 20,
|
|
||||||
group_id: Optional[str] = None,
|
|
||||||
search_type: str = "hybrid",
|
|
||||||
search_limit: int = 12,
|
|
||||||
context_char_budget: int = 8000,
|
|
||||||
reset_group: bool = False,
|
|
||||||
skip_ingest: bool = False,
|
|
||||||
output_dir: Optional[str] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Run LoCoMo benchmark evaluation.
|
|
||||||
|
|
||||||
This function orchestrates the complete evaluation pipeline:
|
|
||||||
1. Load LoCoMo dataset (only QA pairs from first conversation)
|
|
||||||
2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True)
|
|
||||||
3. For each question:
|
|
||||||
- Retrieve relevant information
|
|
||||||
- Generate answer using LLM
|
|
||||||
- Calculate metrics
|
|
||||||
4. Aggregate results and save to file
|
|
||||||
|
|
||||||
Note: By default, only the first conversation is ingested into the database,
|
|
||||||
and only QA pairs from that conversation are evaluated. This ensures that
|
|
||||||
all questions have corresponding memory in the database for retrieval.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_size: Number of QA pairs to evaluate (from first conversation)
|
|
||||||
group_id: Database group ID for retrieval (uses default if None)
|
|
||||||
search_type: "keyword", "embedding", or "hybrid"
|
|
||||||
search_limit: Max documents to retrieve per query
|
|
||||||
context_char_budget: Max characters for context
|
|
||||||
reset_group: Whether to clear and re-ingest data (not implemented)
|
|
||||||
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
|
|
||||||
output_dir: Directory to save results (uses default if None)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with evaluation results including metrics, timing, and samples
|
|
||||||
"""
|
|
||||||
# Use default group_id if not provided
|
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
|
||||||
|
|
||||||
# Determine data path
|
|
||||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
|
||||||
if not os.path.exists(data_path):
|
|
||||||
# Fallback to current directory
|
|
||||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
|
||||||
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print("🚀 Starting LoCoMo Benchmark Evaluation")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print("📊 Configuration:")
|
|
||||||
print(f" Sample size: {sample_size}")
|
|
||||||
print(f" Group ID: {group_id}")
|
|
||||||
print(f" Search type: {search_type}")
|
|
||||||
print(f" Search limit: {search_limit}")
|
|
||||||
print(f" Context budget: {context_char_budget} chars")
|
|
||||||
print(f" Data path: {data_path}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# Step 1: Load LoCoMo data
|
|
||||||
print("📂 Loading LoCoMo dataset...")
|
|
||||||
try:
|
|
||||||
# Only load QA pairs from the first conversation (index 0)
|
|
||||||
# since we only ingest the first conversation into the database
|
|
||||||
qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
|
|
||||||
print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Failed to load data: {e}")
|
|
||||||
return {
|
|
||||||
"error": f"Data loading failed: {e}",
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Step 2: Extract conversations and ingest if needed
|
|
||||||
if skip_ingest:
|
|
||||||
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
|
|
||||||
print(f" Group ID: {group_id}\n")
|
|
||||||
else:
|
|
||||||
print("💾 Checking database ingestion...")
|
|
||||||
try:
|
|
||||||
conversations = extract_conversations(data_path, max_dialogues=1)
|
|
||||||
print(f"📝 Extracted {len(conversations)} conversations")
|
|
||||||
|
|
||||||
# Always ingest for now (ingestion check not implemented)
|
|
||||||
print(f"🔄 Ingesting conversations into group '{group_id}'...")
|
|
||||||
success = await ingest_conversations_if_needed(
|
|
||||||
conversations=conversations,
|
|
||||||
group_id=group_id,
|
|
||||||
reset=reset_group
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
print("✅ Ingestion completed successfully\n")
|
|
||||||
else:
|
|
||||||
print("⚠️ Ingestion may have failed, continuing anyway\n")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Ingestion failed: {e}")
|
|
||||||
print("⚠️ Continuing with evaluation (database may be empty)\n")
|
|
||||||
|
|
||||||
# Step 3: Initialize clients
|
|
||||||
print("🔧 Initializing clients...")
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
|
|
||||||
# Initialize LLM client with database context
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
|
||||||
|
|
||||||
# Initialize embedder
|
|
||||||
with get_db_context() as db:
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
|
||||||
embedder = OpenAIEmbedderClient(
|
|
||||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
|
||||||
)
|
|
||||||
print("✅ Clients initialized\n")
|
|
||||||
|
|
||||||
# Step 4: Process questions
|
|
||||||
print(f"🔍 Processing {len(qa_items)} questions...")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# Tracking variables
|
|
||||||
latencies_search: List[float] = []
|
|
||||||
latencies_llm: List[float] = []
|
|
||||||
context_counts: List[int] = []
|
|
||||||
context_chars: List[int] = []
|
|
||||||
context_tokens: List[int] = []
|
|
||||||
|
|
||||||
# Metric lists
|
|
||||||
f1_scores: List[float] = []
|
|
||||||
bleu1_scores: List[float] = []
|
|
||||||
jaccard_scores: List[float] = []
|
|
||||||
locomo_f1_scores: List[float] = []
|
|
||||||
|
|
||||||
# Per-category tracking
|
|
||||||
category_counts: Dict[str, int] = {}
|
|
||||||
category_f1: Dict[str, List[float]] = {}
|
|
||||||
category_bleu1: Dict[str, List[float]] = {}
|
|
||||||
category_jaccard: Dict[str, List[float]] = {}
|
|
||||||
category_locomo_f1: Dict[str, List[float]] = {}
|
|
||||||
|
|
||||||
# Detailed samples
|
|
||||||
samples: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
# Fixed anchor date for temporal resolution
|
|
||||||
anchor_date = datetime(2023, 5, 8)
|
|
||||||
|
|
||||||
try:
|
|
||||||
for idx, item in enumerate(qa_items, 1):
|
|
||||||
question = item.get("question", "")
|
|
||||||
ground_truth = item.get("answer", "")
|
|
||||||
category = get_category_name(item)
|
|
||||||
|
|
||||||
# Ensure ground truth is a string
|
|
||||||
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
|
|
||||||
|
|
||||||
print(f"[{idx}/{len(qa_items)}] Category: {category}")
|
|
||||||
print(f"❓ Question: {question}")
|
|
||||||
print(f"✅ Ground Truth: {ground_truth_str}")
|
|
||||||
|
|
||||||
# Step 4a: Retrieve relevant information
|
|
||||||
t_search_start = time.time()
|
|
||||||
try:
|
|
||||||
retrieved_info = await retrieve_relevant_information(
|
|
||||||
question=question,
|
|
||||||
group_id=group_id,
|
|
||||||
search_type=search_type,
|
|
||||||
search_limit=search_limit,
|
|
||||||
connector=connector,
|
|
||||||
embedder=embedder
|
|
||||||
)
|
|
||||||
t_search_end = time.time()
|
|
||||||
search_latency = (t_search_end - t_search_start) * 1000
|
|
||||||
latencies_search.append(search_latency)
|
|
||||||
|
|
||||||
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Retrieval failed: {e}")
|
|
||||||
retrieved_info = []
|
|
||||||
search_latency = 0.0
|
|
||||||
latencies_search.append(search_latency)
|
|
||||||
|
|
||||||
# Step 4b: Select and format context
|
|
||||||
context_text = select_and_format_information(
|
|
||||||
retrieved_info=retrieved_info,
|
|
||||||
question=question,
|
|
||||||
max_chars=context_char_budget
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resolve temporal references
|
|
||||||
context_text = resolve_temporal_references(context_text, anchor_date)
|
|
||||||
|
|
||||||
# Add reference date to context
|
|
||||||
if context_text:
|
|
||||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
|
|
||||||
else:
|
|
||||||
context_text = "No relevant context found."
|
|
||||||
|
|
||||||
# Track context statistics
|
|
||||||
context_counts.append(len(retrieved_info))
|
|
||||||
context_chars.append(len(context_text))
|
|
||||||
context_tokens.append(len(context_text.split()))
|
|
||||||
|
|
||||||
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
|
|
||||||
|
|
||||||
# Step 4c: Generate answer with LLM
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": (
|
|
||||||
"You are a precise QA assistant. Answer following these rules:\n"
|
|
||||||
"1) Extract the EXACT information mentioned in the context\n"
|
|
||||||
"2) For time questions: calculate actual dates from relative times\n"
|
|
||||||
"3) Return ONLY the answer text in simplest form\n"
|
|
||||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
|
||||||
"5) If no clear answer found, respond with 'Unknown'"
|
|
||||||
)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"Question: {question}\n\nContext:\n{context_text}"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
t_llm_start = time.time()
|
|
||||||
try:
|
|
||||||
response = await llm_client.chat(messages=messages)
|
|
||||||
t_llm_end = time.time()
|
|
||||||
llm_latency = (t_llm_end - t_llm_start) * 1000
|
|
||||||
latencies_llm.append(llm_latency)
|
|
||||||
|
|
||||||
# Extract prediction from response
|
|
||||||
if hasattr(response, 'content'):
|
|
||||||
prediction = response.content.strip()
|
|
||||||
elif isinstance(response, dict):
|
|
||||||
prediction = response["choices"][0]["message"]["content"].strip()
|
|
||||||
else:
|
|
||||||
prediction = "Unknown"
|
|
||||||
|
|
||||||
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ LLM failed: {e}")
|
|
||||||
prediction = "Unknown"
|
|
||||||
llm_latency = 0.0
|
|
||||||
latencies_llm.append(llm_latency)
|
|
||||||
|
|
||||||
# Step 4d: Calculate metrics
|
|
||||||
f1_val = f1_score(prediction, ground_truth_str)
|
|
||||||
bleu1_val = bleu1(prediction, ground_truth_str)
|
|
||||||
jaccard_val = jaccard(prediction, ground_truth_str)
|
|
||||||
|
|
||||||
# LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
|
|
||||||
if item.get("category") == 1:
|
|
||||||
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
|
|
||||||
else:
|
|
||||||
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
|
|
||||||
|
|
||||||
# Accumulate metrics
|
|
||||||
f1_scores.append(f1_val)
|
|
||||||
bleu1_scores.append(bleu1_val)
|
|
||||||
jaccard_scores.append(jaccard_val)
|
|
||||||
locomo_f1_scores.append(locomo_f1_val)
|
|
||||||
|
|
||||||
# Track by category
|
|
||||||
category_counts[category] = category_counts.get(category, 0) + 1
|
|
||||||
category_f1.setdefault(category, []).append(f1_val)
|
|
||||||
category_bleu1.setdefault(category, []).append(bleu1_val)
|
|
||||||
category_jaccard.setdefault(category, []).append(jaccard_val)
|
|
||||||
category_locomo_f1.setdefault(category, []).append(locomo_f1_val)
|
|
||||||
|
|
||||||
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
|
|
||||||
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Save sample details
|
|
||||||
samples.append({
|
|
||||||
"question": question,
|
|
||||||
"ground_truth": ground_truth_str,
|
|
||||||
"prediction": prediction,
|
|
||||||
"category": category,
|
|
||||||
"metrics": {
|
|
||||||
"f1": f1_val,
|
|
||||||
"bleu1": bleu1_val,
|
|
||||||
"jaccard": jaccard_val,
|
|
||||||
"locomo_f1": locomo_f1_val
|
|
||||||
},
|
|
||||||
"retrieval": {
|
|
||||||
"num_docs": len(retrieved_info),
|
|
||||||
"context_length": len(context_text)
|
|
||||||
},
|
|
||||||
"timing": {
|
|
||||||
"search_ms": search_latency,
|
|
||||||
"llm_ms": llm_latency
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Close connector
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
# Step 5: Aggregate results
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print("📊 Aggregating Results")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# Overall metrics
|
|
||||||
overall_metrics = {
|
|
||||||
"f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0,
|
|
||||||
"bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0,
|
|
||||||
"jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0,
|
|
||||||
"locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Per-category metrics
|
|
||||||
by_category: Dict[str, Dict[str, Any]] = {}
|
|
||||||
for cat in category_counts:
|
|
||||||
f1_list = category_f1.get(cat, [])
|
|
||||||
b1_list = category_bleu1.get(cat, [])
|
|
||||||
j_list = category_jaccard.get(cat, [])
|
|
||||||
lf_list = category_locomo_f1.get(cat, [])
|
|
||||||
|
|
||||||
by_category[cat] = {
|
|
||||||
"count": category_counts[cat],
|
|
||||||
"f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0,
|
|
||||||
"bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0,
|
|
||||||
"jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0,
|
|
||||||
"locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Latency statistics
|
|
||||||
latency = {
|
|
||||||
"search": latency_stats(latencies_search),
|
|
||||||
"llm": latency_stats(latencies_llm)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Context statistics
|
|
||||||
context_stats = {
|
|
||||||
"avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0,
|
|
||||||
"avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0,
|
|
||||||
"avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Build result dictionary
|
|
||||||
result = {
|
|
||||||
"dataset": "locomo",
|
|
||||||
"sample_size": len(qa_items),
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"params": {
|
|
||||||
"group_id": group_id,
|
|
||||||
"search_type": search_type,
|
|
||||||
"search_limit": search_limit,
|
|
||||||
"context_char_budget": context_char_budget,
|
|
||||||
"llm_id": SELECTED_LLM_ID,
|
|
||||||
"embedding_id": SELECTED_EMBEDDING_ID
|
|
||||||
},
|
|
||||||
"overall_metrics": overall_metrics,
|
|
||||||
"by_category": by_category,
|
|
||||||
"latency": latency,
|
|
||||||
"context_stats": context_stats,
|
|
||||||
"samples": samples
|
|
||||||
}
|
|
||||||
|
|
||||||
# Step 6: Save results
|
|
||||||
if output_dir is None:
|
|
||||||
output_dir = os.path.join(
|
|
||||||
os.path.dirname(__file__),
|
|
||||||
"results"
|
|
||||||
)
|
|
||||||
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Generate timestamped filename
|
|
||||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
|
||||||
print(f"✅ Results saved to: {output_path}\n")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Failed to save results: {e}")
|
|
||||||
print("📊 Printing results to console instead:\n")
|
|
||||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""
|
|
||||||
Parse command-line arguments and run benchmark.
|
|
||||||
|
|
||||||
This function provides a CLI interface for running LoCoMo benchmarks
|
|
||||||
with configurable parameters.
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Run LoCoMo benchmark evaluation",
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--sample_size",
|
|
||||||
type=int,
|
|
||||||
default=20,
|
|
||||||
help="Number of QA pairs to evaluate"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--group_id",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Database group ID for retrieval (uses default if not specified)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--search_type",
|
|
||||||
type=str,
|
|
||||||
default="hybrid",
|
|
||||||
choices=["keyword", "embedding", "hybrid"],
|
|
||||||
help="Search strategy to use"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--search_limit",
|
|
||||||
type=int,
|
|
||||||
default=12,
|
|
||||||
help="Maximum number of documents to retrieve per query"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--context_char_budget",
|
|
||||||
type=int,
|
|
||||||
default=8000,
|
|
||||||
help="Maximum characters for context"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--reset_group",
|
|
||||||
action="store_true",
|
|
||||||
help="Clear and re-ingest data (not implemented)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_ingest",
|
|
||||||
action="store_true",
|
|
||||||
help="Skip data ingestion and use existing data in Neo4j"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_dir",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Directory to save results (uses default if not specified)"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Load environment variables
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
# Run benchmark
|
|
||||||
result = asyncio.run(run_locomo_benchmark(
|
|
||||||
sample_size=args.sample_size,
|
|
||||||
group_id=args.group_id,
|
|
||||||
search_type=args.search_type,
|
|
||||||
search_limit=args.search_limit,
|
|
||||||
context_char_budget=args.context_char_budget,
|
|
||||||
reset_group=args.reset_group,
|
|
||||||
skip_ingest=args.skip_ingest,
|
|
||||||
output_dir=args.output_dir
|
|
||||||
))
|
|
||||||
|
|
||||||
# Print summary
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
|
|
||||||
# Check if there was an error
|
|
||||||
if 'error' in result:
|
|
||||||
print("❌ Benchmark Failed!")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Error: {result['error']}")
|
|
||||||
return
|
|
||||||
|
|
||||||
print("🎉 Benchmark Complete!")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print("📊 Final Results:")
|
|
||||||
print(f" Sample size: {result.get('sample_size', 0)}")
|
|
||||||
print(f" F1: {result['overall_metrics']['f1']:.3f}")
|
|
||||||
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
|
|
||||||
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
|
|
||||||
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
|
|
||||||
|
|
||||||
if result.get('context_stats'):
|
|
||||||
print("\n📈 Context Statistics:")
|
|
||||||
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
|
|
||||||
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
|
|
||||||
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
|
|
||||||
|
|
||||||
if result.get('latency'):
|
|
||||||
print("\n⏱️ Latency Statistics:")
|
|
||||||
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
|
|
||||||
f"P50: {result['latency']['search']['p50']:.1f}ms, "
|
|
||||||
f"P95: {result['latency']['search']['p95']:.1f}ms")
|
|
||||||
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
|
|
||||||
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
|
|
||||||
f"P95: {result['latency']['llm']['p95']:.1f}ms")
|
|
||||||
|
|
||||||
if result.get('by_category'):
|
|
||||||
print("\n📂 Results by Category:")
|
|
||||||
for cat, metrics in result['by_category'].items():
|
|
||||||
print(f" {cat}:")
|
|
||||||
print(f" Count: {metrics['count']}")
|
|
||||||
print(f" F1: {metrics['f1']:.3f}")
|
|
||||||
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
|
|
||||||
print(f" Jaccard: {metrics['jaccard']:.3f}")
|
|
||||||
|
|
||||||
print(f"\n{'='*60}\n")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
"""
|
|
||||||
LoCoMo-specific metric calculations.
|
|
||||||
|
|
||||||
This module provides clean, simplified implementations of metrics used for
|
|
||||||
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_text(text: str) -> str:
|
|
||||||
"""
|
|
||||||
Normalize text for LoCoMo evaluation.
|
|
||||||
|
|
||||||
Normalization steps:
|
|
||||||
- Convert to lowercase
|
|
||||||
- Remove commas
|
|
||||||
- Remove stop words (a, an, the, and)
|
|
||||||
- Remove punctuation
|
|
||||||
- Normalize whitespace
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Input text to normalize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized text string with consistent formatting
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> normalize_text("The cat, and the dog")
|
|
||||||
'cat dog'
|
|
||||||
>>> normalize_text("Hello, World!")
|
|
||||||
'hello world'
|
|
||||||
"""
|
|
||||||
# Ensure input is a string
|
|
||||||
text = str(text) if text is not None else ""
|
|
||||||
|
|
||||||
# Convert to lowercase
|
|
||||||
text = text.lower()
|
|
||||||
|
|
||||||
# Remove commas
|
|
||||||
text = re.sub(r"[\,]", " ", text)
|
|
||||||
|
|
||||||
# Remove stop words
|
|
||||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
|
||||||
|
|
||||||
# Remove punctuation (keep only word characters and whitespace)
|
|
||||||
text = re.sub(r"[^\w\s]", " ", text)
|
|
||||||
|
|
||||||
# Normalize whitespace (collapse multiple spaces to single space)
|
|
||||||
text = " ".join(text.split())
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
|
|
||||||
"""
|
|
||||||
Calculate LoCoMo F1 score for single-answer questions.
|
|
||||||
|
|
||||||
Uses token-level precision and recall based on normalized text.
|
|
||||||
Treats tokens as sets (no duplicate counting).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prediction: Model's predicted answer
|
|
||||||
ground_truth: Correct answer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
F1 score between 0.0 and 1.0
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> locomo_f1_score("Paris", "Paris")
|
|
||||||
1.0
|
|
||||||
>>> locomo_f1_score("The cat", "cat")
|
|
||||||
1.0
|
|
||||||
>>> locomo_f1_score("dog", "cat")
|
|
||||||
0.0
|
|
||||||
"""
|
|
||||||
# Ensure inputs are strings
|
|
||||||
pred_str = str(prediction) if prediction is not None else ""
|
|
||||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
|
||||||
|
|
||||||
# Normalize and tokenize
|
|
||||||
pred_tokens = normalize_text(pred_str).split()
|
|
||||||
truth_tokens = normalize_text(truth_str).split()
|
|
||||||
|
|
||||||
# Handle empty cases
|
|
||||||
if not pred_tokens or not truth_tokens:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
# Convert to sets for comparison
|
|
||||||
pred_set = set(pred_tokens)
|
|
||||||
truth_set = set(truth_tokens)
|
|
||||||
|
|
||||||
# Calculate true positives (intersection)
|
|
||||||
true_positives = len(pred_set & truth_set)
|
|
||||||
|
|
||||||
# Calculate precision and recall
|
|
||||||
precision = true_positives / len(pred_set) if pred_set else 0.0
|
|
||||||
recall = true_positives / len(truth_set) if truth_set else 0.0
|
|
||||||
|
|
||||||
# Calculate F1 score
|
|
||||||
if precision + recall == 0:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
f1 = 2 * precision * recall / (precision + recall)
|
|
||||||
return f1
|
|
||||||
|
|
||||||
|
|
||||||
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
|
|
||||||
"""
|
|
||||||
Calculate LoCoMo F1 score for multi-answer questions.
|
|
||||||
|
|
||||||
Handles comma-separated answers by:
|
|
||||||
1. Splitting both prediction and ground truth by commas
|
|
||||||
2. For each ground truth answer, finding the best matching prediction
|
|
||||||
3. Averaging the F1 scores across all ground truth answers
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prediction: Model's predicted answer (may contain multiple comma-separated answers)
|
|
||||||
ground_truth: Correct answer (may contain multiple comma-separated answers)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Average F1 score across all ground truth answers (0.0 to 1.0)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> locomo_multi_f1("Paris, London", "Paris, London")
|
|
||||||
1.0
|
|
||||||
>>> locomo_multi_f1("Paris", "Paris, London")
|
|
||||||
0.5
|
|
||||||
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
|
|
||||||
0.5
|
|
||||||
"""
|
|
||||||
# Ensure inputs are strings
|
|
||||||
pred_str = str(prediction) if prediction is not None else ""
|
|
||||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
|
||||||
|
|
||||||
# Split by commas and strip whitespace
|
|
||||||
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
|
|
||||||
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
|
|
||||||
|
|
||||||
# Handle empty cases
|
|
||||||
if not predictions or not ground_truths:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
# For each ground truth, find the best matching prediction
|
|
||||||
f1_scores = []
|
|
||||||
for gt in ground_truths:
|
|
||||||
# Calculate F1 with each prediction and take the maximum
|
|
||||||
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
|
|
||||||
f1_scores.append(best_f1)
|
|
||||||
|
|
||||||
# Return average F1 across all ground truths
|
|
||||||
return sum(f1_scores) / len(f1_scores)
|
|
||||||
|
|
||||||
|
|
||||||
def get_category_name(item: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
Extract and normalize category name from QA item.
|
|
||||||
|
|
||||||
Handles both numeric categories (1-4) and string categories with various formats.
|
|
||||||
Supports multiple field names: "cat", "category", "type".
|
|
||||||
|
|
||||||
Category mapping:
|
|
||||||
- 1 or "multi-hop" -> "Multi-Hop"
|
|
||||||
- 2 or "temporal" -> "Temporal"
|
|
||||||
- 3 or "open domain" -> "Open Domain"
|
|
||||||
- 4 or "single-hop" -> "Single-Hop"
|
|
||||||
|
|
||||||
Args:
|
|
||||||
item: QA item dictionary containing category information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Standardized category name or "unknown" if not found
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> get_category_name({"category": 1})
|
|
||||||
'Multi-Hop'
|
|
||||||
>>> get_category_name({"cat": "temporal"})
|
|
||||||
'Temporal'
|
|
||||||
>>> get_category_name({"type": "Single-Hop"})
|
|
||||||
'Single-Hop'
|
|
||||||
"""
|
|
||||||
# Numeric category mapping
|
|
||||||
CATEGORY_MAP = {
|
|
||||||
1: "Multi-Hop",
|
|
||||||
2: "Temporal",
|
|
||||||
3: "Open Domain",
|
|
||||||
4: "Single-Hop",
|
|
||||||
}
|
|
||||||
|
|
||||||
# String category aliases (case-insensitive)
|
|
||||||
TYPE_ALIASES = {
|
|
||||||
"single-hop": "Single-Hop",
|
|
||||||
"singlehop": "Single-Hop",
|
|
||||||
"single hop": "Single-Hop",
|
|
||||||
"multi-hop": "Multi-Hop",
|
|
||||||
"multihop": "Multi-Hop",
|
|
||||||
"multi hop": "Multi-Hop",
|
|
||||||
"open domain": "Open Domain",
|
|
||||||
"opendomain": "Open Domain",
|
|
||||||
"temporal": "Temporal",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Try "cat" field first (string category)
|
|
||||||
cat = item.get("cat")
|
|
||||||
if isinstance(cat, str) and cat.strip():
|
|
||||||
name = cat.strip()
|
|
||||||
lower = name.lower()
|
|
||||||
return TYPE_ALIASES.get(lower, name)
|
|
||||||
|
|
||||||
# Try "category" field (can be int or string)
|
|
||||||
cat_num = item.get("category")
|
|
||||||
if isinstance(cat_num, int):
|
|
||||||
return CATEGORY_MAP.get(cat_num, "unknown")
|
|
||||||
elif isinstance(cat_num, str) and cat_num.strip():
|
|
||||||
lower = cat_num.strip().lower()
|
|
||||||
return TYPE_ALIASES.get(lower, cat_num.strip())
|
|
||||||
|
|
||||||
# Try "type" field as fallback
|
|
||||||
cat_type = item.get("type")
|
|
||||||
if isinstance(cat_type, str) and cat_type.strip():
|
|
||||||
lower = cat_type.strip().lower()
|
|
||||||
return TYPE_ALIASES.get(lower, cat_type.strip())
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
@@ -1,810 +0,0 @@
|
|||||||
# file name: check_neo4j_connection_fixed.py
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
# 1
|
|
||||||
# 添加项目根目录到路径
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
project_root = os.path.dirname(current_dir)
|
|
||||||
if project_root not in sys.path:
|
|
||||||
sys.path.insert(0, project_root)
|
|
||||||
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
|
|
||||||
src_dir = os.path.join(project_root, "src")
|
|
||||||
if src_dir not in sys.path:
|
|
||||||
sys.path.insert(0, src_dir)
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
|
|
||||||
def _loc_normalize(text: str) -> str:
|
|
||||||
text = str(text) if text is not None else ""
|
|
||||||
text = text.lower()
|
|
||||||
text = re.sub(r"[\,]", " ", text)
|
|
||||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
|
||||||
text = re.sub(r"[^\w\s]", " ", text)
|
|
||||||
text = " ".join(text.split())
|
|
||||||
return text
|
|
||||||
|
|
||||||
# 尝试从 metrics.py 导入基础指标
|
|
||||||
try:
|
|
||||||
from common.metrics import bleu1, f1_score, jaccard
|
|
||||||
print("✅ 从 metrics.py 导入基础指标成功")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
|
||||||
# 回退到本地实现
|
|
||||||
def f1_score(pred: str, ref: str) -> float:
|
|
||||||
pred_str = str(pred) if pred is not None else ""
|
|
||||||
ref_str = str(ref) if ref is not None else ""
|
|
||||||
|
|
||||||
p_tokens = _loc_normalize(pred_str).split()
|
|
||||||
r_tokens = _loc_normalize(ref_str).split()
|
|
||||||
if not p_tokens and not r_tokens:
|
|
||||||
return 1.0
|
|
||||||
if not p_tokens or not r_tokens:
|
|
||||||
return 0.0
|
|
||||||
p_set = set(p_tokens)
|
|
||||||
r_set = set(r_tokens)
|
|
||||||
tp = len(p_set & r_set)
|
|
||||||
precision = tp / len(p_set) if p_set else 0.0
|
|
||||||
recall = tp / len(r_set) if r_set else 0.0
|
|
||||||
if precision + recall == 0:
|
|
||||||
return 0.0
|
|
||||||
return 2 * precision * recall / (precision + recall)
|
|
||||||
|
|
||||||
def bleu1(pred: str, ref: str) -> float:
|
|
||||||
pred_str = str(pred) if pred is not None else ""
|
|
||||||
ref_str = str(ref) if ref is not None else ""
|
|
||||||
|
|
||||||
p_tokens = _loc_normalize(pred_str).split()
|
|
||||||
r_tokens = _loc_normalize(ref_str).split()
|
|
||||||
if not p_tokens:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
r_counts = {}
|
|
||||||
for t in r_tokens:
|
|
||||||
r_counts[t] = r_counts.get(t, 0) + 1
|
|
||||||
|
|
||||||
clipped = 0
|
|
||||||
p_counts = {}
|
|
||||||
for t in p_tokens:
|
|
||||||
p_counts[t] = p_counts.get(t, 0) + 1
|
|
||||||
|
|
||||||
for t, c in p_counts.items():
|
|
||||||
clipped += min(c, r_counts.get(t, 0))
|
|
||||||
|
|
||||||
precision = clipped / max(len(p_tokens), 1)
|
|
||||||
ref_len = len(r_tokens)
|
|
||||||
pred_len = len(p_tokens)
|
|
||||||
|
|
||||||
if pred_len > ref_len or pred_len == 0:
|
|
||||||
bp = 1.0
|
|
||||||
else:
|
|
||||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
|
||||||
|
|
||||||
return bp * precision
|
|
||||||
|
|
||||||
def jaccard(pred: str, ref: str) -> float:
|
|
||||||
pred_str = str(pred) if pred is not None else ""
|
|
||||||
ref_str = str(ref) if ref is not None else ""
|
|
||||||
|
|
||||||
p = set(_loc_normalize(pred_str).split())
|
|
||||||
r = set(_loc_normalize(ref_str).split())
|
|
||||||
if not p and not r:
|
|
||||||
return 1.0
|
|
||||||
if not p or not r:
|
|
||||||
return 0.0
|
|
||||||
return len(p & r) / len(p | r)
|
|
||||||
|
|
||||||
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
|
|
||||||
try:
|
|
||||||
# 添加 evaluation 目录路径
|
|
||||||
evaluation_dir = os.path.join(project_root, "evaluation")
|
|
||||||
if evaluation_dir not in sys.path:
|
|
||||||
sys.path.insert(0, evaluation_dir)
|
|
||||||
|
|
||||||
# 尝试从不同位置导入
|
|
||||||
try:
|
|
||||||
from locomo.qwen_search_eval import (
|
|
||||||
_resolve_relative_times,
|
|
||||||
loc_f1_score,
|
|
||||||
loc_multi_f1,
|
|
||||||
)
|
|
||||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
|
||||||
except ImportError:
|
|
||||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
|
||||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
|
||||||
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
|
|
||||||
# 回退到本地实现 LoCoMo 特定函数
|
|
||||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
|
||||||
t = str(text) if text is not None else ""
|
|
||||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
|
|
||||||
def _ago_repl(m: re.Match[str]) -> str:
|
|
||||||
n = int(m.group(1))
|
|
||||||
return (anchor - timedelta(days=n)).date().isoformat()
|
|
||||||
def _in_repl(m: re.Match[str]) -> str:
|
|
||||||
n = int(m.group(1))
|
|
||||||
return (anchor + timedelta(days=n)).date().isoformat()
|
|
||||||
|
|
||||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
return t
|
|
||||||
|
|
||||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
|
||||||
p_tokens = _loc_normalize(prediction).split()
|
|
||||||
g_tokens = _loc_normalize(ground_truth).split()
|
|
||||||
if not p_tokens or not g_tokens:
|
|
||||||
return 0.0
|
|
||||||
p = set(p_tokens)
|
|
||||||
g = set(g_tokens)
|
|
||||||
tp = len(p & g)
|
|
||||||
precision = tp / len(p) if p else 0.0
|
|
||||||
recall = tp / len(g) if g else 0.0
|
|
||||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
|
||||||
|
|
||||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
|
||||||
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
|
|
||||||
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
|
|
||||||
if not predictions or not ground_truths:
|
|
||||||
return 0.0
|
|
||||||
def _f1(a: str, b: str) -> float:
|
|
||||||
return loc_f1_score(a, b)
|
|
||||||
vals = []
|
|
||||||
for gt in ground_truths:
|
|
||||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
|
||||||
return sum(vals) / len(vals)
|
|
||||||
|
|
||||||
|
|
||||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
|
|
||||||
"""基于问题关键词智能选择上下文"""
|
|
||||||
if not contexts:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 提取问题关键词(只保留有意义的词)
|
|
||||||
question_lower = question.lower()
|
|
||||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
|
||||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
|
||||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
|
||||||
|
|
||||||
print(f"🔍 问题关键词: {question_words}")
|
|
||||||
|
|
||||||
# 给每个上下文打分
|
|
||||||
scored_contexts = []
|
|
||||||
for i, context in enumerate(contexts):
|
|
||||||
context_lower = context.lower()
|
|
||||||
score = 0
|
|
||||||
|
|
||||||
# 关键词匹配得分
|
|
||||||
keyword_matches = 0
|
|
||||||
for word in question_words:
|
|
||||||
if word in context_lower:
|
|
||||||
keyword_matches += 1
|
|
||||||
# 关键词出现次数越多,得分越高
|
|
||||||
score += context_lower.count(word) * 2
|
|
||||||
|
|
||||||
# 上下文长度得分(适中的长度更好)
|
|
||||||
context_len = len(context)
|
|
||||||
if 100 < context_len < 2000: # 理想长度范围
|
|
||||||
score += 5
|
|
||||||
elif context_len >= 2000: # 太长可能包含无关信息
|
|
||||||
score += 2
|
|
||||||
|
|
||||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
|
||||||
if i < 3:
|
|
||||||
score += 3
|
|
||||||
|
|
||||||
scored_contexts.append((score, context, keyword_matches))
|
|
||||||
|
|
||||||
# 按得分排序
|
|
||||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
|
|
||||||
# 选择高得分的上下文,直到达到字符限制
|
|
||||||
selected = []
|
|
||||||
total_chars = 0
|
|
||||||
selected_count = 0
|
|
||||||
|
|
||||||
print("📊 上下文相关性分析:")
|
|
||||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
|
||||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
|
||||||
|
|
||||||
for score, context, matches in scored_contexts:
|
|
||||||
if total_chars + len(context) <= max_chars:
|
|
||||||
selected.append(context)
|
|
||||||
total_chars += len(context)
|
|
||||||
selected_count += 1
|
|
||||||
else:
|
|
||||||
# 如果这个上下文得分很高但放不下,尝试截取
|
|
||||||
if score > 10 and total_chars < max_chars - 500:
|
|
||||||
remaining = max_chars - total_chars
|
|
||||||
# 找到包含关键词的部分
|
|
||||||
lines = context.split('\n')
|
|
||||||
relevant_lines = []
|
|
||||||
current_chars = 0
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
line_lower = line.lower()
|
|
||||||
line_relevance = any(word in line_lower for word in question_words)
|
|
||||||
|
|
||||||
if line_relevance and current_chars < remaining - 100:
|
|
||||||
relevant_lines.append(line)
|
|
||||||
current_chars += len(line)
|
|
||||||
|
|
||||||
if relevant_lines:
|
|
||||||
truncated = '\n'.join(relevant_lines)
|
|
||||||
if len(truncated) > 100: # 确保有足够内容
|
|
||||||
selected.append(truncated + "\n[相关内容截断...]")
|
|
||||||
total_chars += len(truncated)
|
|
||||||
selected_count += 1
|
|
||||||
break # 不再尝试添加更多上下文
|
|
||||||
|
|
||||||
result = "\n\n".join(selected)
|
|
||||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
|
|
||||||
"""根据问题复杂度和进度动态调整检索参数"""
|
|
||||||
|
|
||||||
# 分析问题复杂度
|
|
||||||
word_count = len(question.split())
|
|
||||||
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
|
|
||||||
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
|
|
||||||
|
|
||||||
# 根据进度调整 - 后期问题可能需要更精确的检索
|
|
||||||
progress_factor = question_index / total_questions
|
|
||||||
|
|
||||||
base_limit = 12
|
|
||||||
if has_temporal and has_multi_hop:
|
|
||||||
base_limit = 20
|
|
||||||
elif word_count > 8:
|
|
||||||
base_limit = 16
|
|
||||||
|
|
||||||
# 随着测试进行,逐渐收紧检索范围
|
|
||||||
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
|
|
||||||
|
|
||||||
# 动态调整最大字符数
|
|
||||||
max_chars = 8000 + 4000 * (1 - progress_factor)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"limit": adjusted_limit,
|
|
||||||
"max_chars": int(max_chars)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class EnhancedEvaluationMonitor:
|
|
||||||
def __init__(self, reset_interval=5, performance_threshold=0.6):
|
|
||||||
self.question_count = 0
|
|
||||||
self.reset_interval = reset_interval
|
|
||||||
self.performance_threshold = performance_threshold
|
|
||||||
self.consecutive_low_scores = 0
|
|
||||||
self.performance_history = []
|
|
||||||
self.recent_f1_scores = []
|
|
||||||
|
|
||||||
def should_reset_connections(self, current_f1=None):
|
|
||||||
"""基于计数和性能双重判断"""
|
|
||||||
# 定期重置
|
|
||||||
if self.question_count % self.reset_interval == 0:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 性能驱动的重置
|
|
||||||
if current_f1 is not None and current_f1 < self.performance_threshold:
|
|
||||||
self.consecutive_low_scores += 1
|
|
||||||
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
|
|
||||||
print("🚨 连续低分,触发紧急重置")
|
|
||||||
self.consecutive_low_scores = 0
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
self.consecutive_low_scores = 0
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
|
|
||||||
"""记录性能指标,检测衰减"""
|
|
||||||
self.performance_history.append({
|
|
||||||
'index': question_index,
|
|
||||||
'metrics': metrics,
|
|
||||||
'context_length': context_length,
|
|
||||||
'retrieved_docs': retrieved_docs,
|
|
||||||
'timestamp': time.time()
|
|
||||||
})
|
|
||||||
|
|
||||||
# 记录最近的F1分数
|
|
||||||
self.recent_f1_scores.append(metrics['f1'])
|
|
||||||
if len(self.recent_f1_scores) > 5:
|
|
||||||
self.recent_f1_scores.pop(0)
|
|
||||||
|
|
||||||
def get_recent_performance(self):
|
|
||||||
"""获取近期平均性能"""
|
|
||||||
if not self.recent_f1_scores:
|
|
||||||
return 0.5
|
|
||||||
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
|
|
||||||
|
|
||||||
def get_performance_trend(self):
|
|
||||||
"""分析性能趋势"""
|
|
||||||
if len(self.performance_history) < 2:
|
|
||||||
return "stable"
|
|
||||||
|
|
||||||
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
|
|
||||||
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
|
|
||||||
|
|
||||||
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
|
|
||||||
return "stable"
|
|
||||||
|
|
||||||
recent_avg = sum(recent_metrics) / len(recent_metrics)
|
|
||||||
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
|
|
||||||
|
|
||||||
if recent_avg < earlier_avg * 0.8:
|
|
||||||
return "degrading"
|
|
||||||
elif recent_avg > earlier_avg * 1.1:
|
|
||||||
return "improving"
|
|
||||||
else:
|
|
||||||
return "stable"
|
|
||||||
|
|
||||||
|
|
||||||
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
|
|
||||||
"""基于问题复杂度和近期性能动态调整检索参数"""
|
|
||||||
|
|
||||||
# 基础参数
|
|
||||||
base_params = get_dynamic_search_params(question, question_index, total_questions)
|
|
||||||
|
|
||||||
# 性能自适应调整
|
|
||||||
if recent_performance < 0.5: # 近期表现差
|
|
||||||
# 增加检索范围,尝试获取更多上下文
|
|
||||||
base_params["limit"] = min(base_params["limit"] + 5, 25)
|
|
||||||
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
|
|
||||||
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
|
||||||
|
|
||||||
elif recent_performance > 0.8: # 近期表现好
|
|
||||||
# 收紧检索,提高精度
|
|
||||||
base_params["limit"] = max(base_params["limit"] - 2, 8)
|
|
||||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
|
|
||||||
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
|
||||||
|
|
||||||
# 中间阶段特殊处理
|
|
||||||
mid_sequence_factor = abs(question_index / total_questions - 0.5)
|
|
||||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
|
||||||
print("🎯 中间阶段:使用更精确的检索策略")
|
|
||||||
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
|
|
||||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
|
|
||||||
|
|
||||||
return base_params
|
|
||||||
|
|
||||||
|
|
||||||
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
|
|
||||||
"""考虑问题序列位置的智能选择"""
|
|
||||||
|
|
||||||
if not contexts:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 在序列中间阶段使用更严格的筛选
|
|
||||||
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
|
|
||||||
|
|
||||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
|
||||||
print("🎯 中间阶段:使用严格上下文筛选")
|
|
||||||
|
|
||||||
# 提取问题关键词
|
|
||||||
question_lower = question.lower()
|
|
||||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
|
||||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
|
||||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
|
||||||
|
|
||||||
# 只保留高度相关的上下文
|
|
||||||
filtered_contexts = []
|
|
||||||
for context in contexts:
|
|
||||||
context_lower = context.lower()
|
|
||||||
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
|
|
||||||
|
|
||||||
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
|
|
||||||
if any(char.isdigit() for char in context):
|
|
||||||
relevance_score += 2
|
|
||||||
|
|
||||||
# 提高阈值:只有得分>=3的上下文才保留
|
|
||||||
if relevance_score >= 3:
|
|
||||||
filtered_contexts.append(context)
|
|
||||||
else:
|
|
||||||
print(f" - 过滤低分上下文: 得分={relevance_score}")
|
|
||||||
|
|
||||||
contexts = filtered_contexts
|
|
||||||
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
|
|
||||||
|
|
||||||
# 使用原有的智能选择逻辑
|
|
||||||
return smart_context_selection(contexts, question, max_chars)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_enhanced_evaluation():
|
|
||||||
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
|
|
||||||
try:
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
except Exception:
|
|
||||||
def load_dotenv():
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|
||||||
from app.core.memory.utils.config.definitions import (
|
|
||||||
SELECTED_EMBEDDING_ID,
|
|
||||||
SELECTED_LLM_ID,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
# 加载数据
|
|
||||||
# 获取项目根目录
|
|
||||||
current_file = os.path.abspath(__file__)
|
|
||||||
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
|
|
||||||
memory_dir = os.path.dirname(evaluation_dir) # memory目录
|
|
||||||
data_path = os.path.join(memory_dir, "data", "locomo10.json")
|
|
||||||
with open(data_path, "r", encoding="utf-8") as f:
|
|
||||||
raw = json.load(f)
|
|
||||||
|
|
||||||
qa_items = []
|
|
||||||
if isinstance(raw, list):
|
|
||||||
for entry in raw:
|
|
||||||
qa_items.extend(entry.get("qa", []))
|
|
||||||
else:
|
|
||||||
qa_items.extend(raw.get("qa", []))
|
|
||||||
|
|
||||||
items = qa_items[:20] # 测试多少个问题
|
|
||||||
|
|
||||||
# 初始化增强监控器
|
|
||||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
|
||||||
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
|
||||||
|
|
||||||
# 初始化embedder
|
|
||||||
with get_db_context() as db:
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
|
||||||
embedder = OpenAIEmbedderClient(
|
|
||||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 初始化连接器
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
|
|
||||||
# 初始化结果字典
|
|
||||||
results = {
|
|
||||||
"questions": [],
|
|
||||||
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
|
|
||||||
"category_metrics": {},
|
|
||||||
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
|
|
||||||
"performance_trend": "stable",
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"enhanced_strategy": True
|
|
||||||
}
|
|
||||||
|
|
||||||
total_f1 = 0.0
|
|
||||||
total_bleu1 = 0.0
|
|
||||||
total_jaccard = 0.0
|
|
||||||
total_loc_f1 = 0.0
|
|
||||||
total_context_length = 0
|
|
||||||
total_retrieved_docs = 0
|
|
||||||
category_stats = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
for i, item in enumerate(items):
|
|
||||||
monitor.question_count += 1
|
|
||||||
|
|
||||||
# 获取近期性能用于重置判断
|
|
||||||
recent_performance = monitor.get_recent_performance()
|
|
||||||
|
|
||||||
# 增强的重置判断
|
|
||||||
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
|
|
||||||
if should_reset and i > 0:
|
|
||||||
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
|
|
||||||
await connector.close()
|
|
||||||
connector = Neo4jConnector() # 创建新连接
|
|
||||||
print("✅ 连接重置完成")
|
|
||||||
|
|
||||||
q = item.get("question", "")
|
|
||||||
ref = item.get("answer", "")
|
|
||||||
ref_str = str(ref) if ref is not None else ""
|
|
||||||
|
|
||||||
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
|
|
||||||
print(f"✅ 真实答案: {ref_str}")
|
|
||||||
|
|
||||||
# 分类别统计
|
|
||||||
category = "Unknown"
|
|
||||||
if item.get("category") == 1:
|
|
||||||
category = "Multi-Hop"
|
|
||||||
elif item.get("category") == 2:
|
|
||||||
category = "Temporal"
|
|
||||||
elif item.get("category") == 3:
|
|
||||||
category = "Open Domain"
|
|
||||||
elif item.get("category") == 4:
|
|
||||||
category = "Single-Hop"
|
|
||||||
|
|
||||||
# 增强的检索参数
|
|
||||||
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
|
|
||||||
search_limit = search_params["limit"]
|
|
||||||
max_chars = search_params["max_chars"]
|
|
||||||
|
|
||||||
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
|
|
||||||
|
|
||||||
# 使用项目标准的混合检索方法
|
|
||||||
t0 = time.time()
|
|
||||||
contexts_all = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 使用统一的搜索服务
|
|
||||||
from app.core.memory.storage_services.search import run_hybrid_search
|
|
||||||
|
|
||||||
print("🔀 使用混合搜索服务...")
|
|
||||||
|
|
||||||
search_results = await run_hybrid_search(
|
|
||||||
query_text=q,
|
|
||||||
search_type="hybrid",
|
|
||||||
group_id="locomo_sk",
|
|
||||||
limit=20,
|
|
||||||
include=["statements", "chunks", "entities", "summaries"],
|
|
||||||
alpha=0.6, # BM25权重
|
|
||||||
embedding_id=SELECTED_EMBEDDING_ID
|
|
||||||
)
|
|
||||||
|
|
||||||
# 处理搜索结果 - 新的搜索服务返回统一的结构
|
|
||||||
chunks = search_results.get("chunks", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
summaries = search_results.get("summaries", [])
|
|
||||||
|
|
||||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
|
||||||
|
|
||||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
|
||||||
for c in chunks:
|
|
||||||
content = str(c.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
contexts_all.append(content)
|
|
||||||
|
|
||||||
for s in statements:
|
|
||||||
stmt_text = str(s.get("statement", "")).strip()
|
|
||||||
if stmt_text:
|
|
||||||
contexts_all.append(stmt_text)
|
|
||||||
|
|
||||||
for sm in summaries:
|
|
||||||
summary_text = str(sm.get("summary", "")).strip()
|
|
||||||
if summary_text:
|
|
||||||
contexts_all.append(summary_text)
|
|
||||||
|
|
||||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
|
||||||
scored = [e for e in entities if e.get("score") is not None]
|
|
||||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
|
||||||
if top_entities:
|
|
||||||
summary_lines = []
|
|
||||||
for e in top_entities:
|
|
||||||
name = str(e.get("name", "")).strip()
|
|
||||||
etype = str(e.get("entity_type", "")).strip()
|
|
||||||
score = e.get("score")
|
|
||||||
if name:
|
|
||||||
meta = []
|
|
||||||
if etype:
|
|
||||||
meta.append(f"type={etype}")
|
|
||||||
if isinstance(score, (int, float)):
|
|
||||||
meta.append(f"score={score:.3f}")
|
|
||||||
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
|
|
||||||
if summary_lines:
|
|
||||||
contexts_all.append("\n".join(summary_lines))
|
|
||||||
|
|
||||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 检索失败: {e}")
|
|
||||||
contexts_all = []
|
|
||||||
|
|
||||||
t1 = time.time()
|
|
||||||
search_time = (t1 - t0) * 1000
|
|
||||||
|
|
||||||
# 增强的上下文选择
|
|
||||||
context_text = ""
|
|
||||||
if contexts_all:
|
|
||||||
# 使用增强的上下文选择
|
|
||||||
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
|
|
||||||
|
|
||||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
|
||||||
if len(context_text) > max_chars:
|
|
||||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
|
||||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
|
||||||
|
|
||||||
# 时间解析
|
|
||||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
|
||||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
|
||||||
|
|
||||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
|
||||||
|
|
||||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
|
||||||
|
|
||||||
# 显示不同上下文的预览(不只是第一条)
|
|
||||||
print("🔍 上下文预览:")
|
|
||||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
|
||||||
preview = context[:150].replace('\n', ' ')
|
|
||||||
print(f" 上下文{j+1}: {preview}...")
|
|
||||||
|
|
||||||
# 🔍 调试:检查答案是否在上下文中
|
|
||||||
if ref_str and ref_str.strip():
|
|
||||||
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
|
|
||||||
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
print("❌ 没有检索到有效上下文")
|
|
||||||
context_text = "No relevant context found."
|
|
||||||
|
|
||||||
# LLM 回答
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": (
|
|
||||||
"You are a precise QA assistant. Answer following these rules:\n"
|
|
||||||
"1) Extract the EXACT information mentioned in the context\n"
|
|
||||||
"2) For time questions: calculate actual dates from relative times\n"
|
|
||||||
"3) Return ONLY the answer text in simplest form\n"
|
|
||||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
|
||||||
"5) If no clear answer found, respond with 'Unknown'"
|
|
||||||
)},
|
|
||||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
|
||||||
]
|
|
||||||
|
|
||||||
t2 = time.time()
|
|
||||||
try:
|
|
||||||
# 使用异步调用
|
|
||||||
resp = await llm.chat(messages=messages)
|
|
||||||
# 兼容不同的响应格式
|
|
||||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ LLM 生成失败: {e}")
|
|
||||||
pred = "Unknown"
|
|
||||||
t3 = time.time()
|
|
||||||
llm_time = (t3 - t2) * 1000
|
|
||||||
|
|
||||||
# 计算指标 - 使用导入的指标函数
|
|
||||||
f1_val = f1_score(pred, ref_str)
|
|
||||||
bleu1_val = bleu1(pred, ref_str)
|
|
||||||
jaccard_val = jaccard(pred, ref_str)
|
|
||||||
loc_f1_val = loc_f1_score(pred, ref_str)
|
|
||||||
|
|
||||||
print(f"🤖 LLM 回答: {pred}")
|
|
||||||
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
|
|
||||||
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
|
|
||||||
|
|
||||||
# 更新统计
|
|
||||||
total_f1 += f1_val
|
|
||||||
total_bleu1 += bleu1_val
|
|
||||||
total_jaccard += jaccard_val
|
|
||||||
total_loc_f1 += loc_f1_val
|
|
||||||
total_context_length += len(context_text)
|
|
||||||
total_retrieved_docs += len(contexts_all)
|
|
||||||
|
|
||||||
if category not in category_stats:
|
|
||||||
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
|
|
||||||
|
|
||||||
category_stats[category]["count"] += 1
|
|
||||||
category_stats[category]["f1_sum"] += f1_val
|
|
||||||
category_stats[category]["b1_sum"] += bleu1_val
|
|
||||||
category_stats[category]["j_sum"] += jaccard_val
|
|
||||||
category_stats[category]["loc_f1_sum"] += loc_f1_val
|
|
||||||
|
|
||||||
# 记录性能指标
|
|
||||||
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
|
|
||||||
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
|
|
||||||
|
|
||||||
# 保存结果
|
|
||||||
question_result = {
|
|
||||||
"question": q,
|
|
||||||
"ground_truth": ref_str,
|
|
||||||
"prediction": pred,
|
|
||||||
"category": category,
|
|
||||||
"metrics": metrics,
|
|
||||||
"retrieval": {
|
|
||||||
"retrieved_documents": len(contexts_all),
|
|
||||||
"context_length": len(context_text),
|
|
||||||
"search_limit": search_limit,
|
|
||||||
"max_chars": max_chars,
|
|
||||||
"recent_performance": recent_performance
|
|
||||||
},
|
|
||||||
"timing": {
|
|
||||||
"search_ms": search_time,
|
|
||||||
"llm_ms": llm_time
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
results["questions"].append(question_result)
|
|
||||||
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 评估过程中发生错误: {e}")
|
|
||||||
# 即使出错,也返回已有的结果
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
# 计算总体指标
|
|
||||||
n = len(items)
|
|
||||||
if n > 0:
|
|
||||||
results["overall_metrics"] = {
|
|
||||||
"f1": total_f1 / n,
|
|
||||||
"b1": total_bleu1 / n,
|
|
||||||
"j": total_jaccard / n,
|
|
||||||
"loc_f1": total_loc_f1 / n
|
|
||||||
}
|
|
||||||
|
|
||||||
for category, stats in category_stats.items():
|
|
||||||
count = stats["count"]
|
|
||||||
results["category_metrics"][category] = {
|
|
||||||
"count": count,
|
|
||||||
"f1": stats["f1_sum"] / count,
|
|
||||||
"bleu1": stats["b1_sum"] / count,
|
|
||||||
"jaccard": stats["j_sum"] / count,
|
|
||||||
"loc_f1": stats["loc_f1_sum"] / count
|
|
||||||
}
|
|
||||||
|
|
||||||
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
|
|
||||||
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
|
|
||||||
|
|
||||||
# 分析性能趋势
|
|
||||||
results["performance_trend"] = monitor.get_performance_trend()
|
|
||||||
results["reset_interval"] = monitor.reset_interval
|
|
||||||
results["total_questions_processed"] = monitor.question_count
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
|
|
||||||
print("📋 增强特性:")
|
|
||||||
print(" - 双重重置策略:定期重置 + 性能驱动重置")
|
|
||||||
print(" - 动态检索参数:基于近期性能自适应调整")
|
|
||||||
print(" - 中间阶段严格筛选:提高上下文质量要求")
|
|
||||||
print(" - 连续性能监控:实时检测性能衰减")
|
|
||||||
|
|
||||||
result = asyncio.run(run_enhanced_evaluation())
|
|
||||||
|
|
||||||
print("\n📊 最终评估结果:")
|
|
||||||
print("总体指标:")
|
|
||||||
print(f" F1: {result['overall_metrics']['f1']:.4f}")
|
|
||||||
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
|
|
||||||
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
|
|
||||||
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
|
|
||||||
|
|
||||||
print("\n分类别指标:")
|
|
||||||
for category, metrics in result['category_metrics'].items():
|
|
||||||
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
|
|
||||||
|
|
||||||
print("\n检索统计:")
|
|
||||||
stats = result['retrieval_stats']
|
|
||||||
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
|
|
||||||
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
|
|
||||||
|
|
||||||
print(f"\n性能趋势: {result['performance_trend']}")
|
|
||||||
print(f"重置间隔: 每{result['reset_interval']}个问题")
|
|
||||||
print(f"处理问题总数: {result['total_questions_processed']}")
|
|
||||||
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
|
|
||||||
|
|
||||||
|
|
||||||
# 保存结果到指定目录
|
|
||||||
# 使用代码文件所在目录的绝对路径
|
|
||||||
current_file_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
output_dir = os.path.join(current_file_dir, "results")
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
|
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
|
||||||
print(f"\n详细结果已保存到: {output_file}")
|
|
||||||
@@ -1,626 +0,0 @@
|
|||||||
"""
|
|
||||||
LoCoMo Utilities Module
|
|
||||||
|
|
||||||
This module provides helper functions for the LoCoMo benchmark evaluation:
|
|
||||||
- Data loading from JSON files
|
|
||||||
- Conversation extraction for ingestion
|
|
||||||
- Temporal reference resolution
|
|
||||||
- Context selection and formatting
|
|
||||||
- Retrieval wrapper functions
|
|
||||||
- Ingestion wrapper functions
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
from app.core.memory.utils.definitions import PROJECT_ROOT
|
|
||||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def load_locomo_data(
|
|
||||||
data_path: str,
|
|
||||||
sample_size: int,
|
|
||||||
conversation_index: int = 0
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Load LoCoMo dataset from JSON file.
|
|
||||||
|
|
||||||
The LoCoMo dataset structure is a list of conversation objects, where each
|
|
||||||
object contains a "qa" list of question-answer pairs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_path: Path to locomo10.json file
|
|
||||||
sample_size: Number of QA pairs to load (limits total QA items returned)
|
|
||||||
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of QA item dictionaries, each containing:
|
|
||||||
- question: str
|
|
||||||
- answer: str
|
|
||||||
- category: int (1-4)
|
|
||||||
- evidence: List[str]
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If data_path does not exist
|
|
||||||
json.JSONDecodeError: If file is not valid JSON
|
|
||||||
IndexError: If conversation_index is out of range
|
|
||||||
"""
|
|
||||||
if not os.path.exists(data_path):
|
|
||||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
|
||||||
|
|
||||||
with open(data_path, "r", encoding="utf-8") as f:
|
|
||||||
raw = json.load(f)
|
|
||||||
|
|
||||||
# LoCoMo data structure: list of objects, each with a "qa" list
|
|
||||||
qa_items: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
if isinstance(raw, list):
|
|
||||||
# Only load QA pairs from the specified conversation
|
|
||||||
if conversation_index < len(raw):
|
|
||||||
entry = raw[conversation_index]
|
|
||||||
if isinstance(entry, dict) and "qa" in entry:
|
|
||||||
qa_items.extend(entry.get("qa", []))
|
|
||||||
else:
|
|
||||||
raise IndexError(
|
|
||||||
f"Conversation index {conversation_index} out of range. "
|
|
||||||
f"Dataset has {len(raw)} conversations."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Fallback: single object with qa list
|
|
||||||
if conversation_index == 0:
|
|
||||||
qa_items.extend(raw.get("qa", []))
|
|
||||||
else:
|
|
||||||
raise IndexError(
|
|
||||||
f"Conversation index {conversation_index} out of range. "
|
|
||||||
f"Dataset has only 1 conversation."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return only the requested sample size
|
|
||||||
return qa_items[:sample_size]
|
|
||||||
|
|
||||||
|
|
||||||
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
|
|
||||||
"""
|
|
||||||
Extract conversation texts from LoCoMo data for ingestion.
|
|
||||||
|
|
||||||
This function extracts the raw conversation dialogues from the LoCoMo dataset
|
|
||||||
so they can be ingested into the memory system. Each conversation is formatted
|
|
||||||
as a multi-line string with "role: message" format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_path: Path to locomo10.json file
|
|
||||||
max_dialogues: Maximum number of dialogues to extract (default: 1)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of conversation strings formatted for ingestion.
|
|
||||||
Each string contains multiple lines in format "role: message"
|
|
||||||
|
|
||||||
Example output:
|
|
||||||
[
|
|
||||||
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
|
|
||||||
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
if not os.path.exists(data_path):
|
|
||||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
|
||||||
|
|
||||||
with open(data_path, "r", encoding="utf-8") as f:
|
|
||||||
raw = json.load(f)
|
|
||||||
|
|
||||||
# Ensure we have a list of entries
|
|
||||||
entries = raw if isinstance(raw, list) else [raw]
|
|
||||||
|
|
||||||
contents: List[str] = []
|
|
||||||
|
|
||||||
for i, entry in enumerate(entries[:max_dialogues]):
|
|
||||||
if not isinstance(entry, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
conv = entry.get("conversation", {})
|
|
||||||
|
|
||||||
if not isinstance(conv, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
lines: List[str] = []
|
|
||||||
|
|
||||||
# Collect all session_* messages
|
|
||||||
for key, val in sorted(conv.items()):
|
|
||||||
if isinstance(val, list) and key.startswith("session_"):
|
|
||||||
for msg in val:
|
|
||||||
if not isinstance(msg, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
role = msg.get("speaker") or "User"
|
|
||||||
text = msg.get("text") or ""
|
|
||||||
text = str(text).strip()
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
continue
|
|
||||||
|
|
||||||
lines.append(f"{role}: {text}")
|
|
||||||
|
|
||||||
if lines:
|
|
||||||
contents.append("\n".join(lines))
|
|
||||||
|
|
||||||
return contents
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
|
|
||||||
"""
|
|
||||||
Resolve relative temporal references to absolute dates.
|
|
||||||
|
|
||||||
This function converts relative time expressions (like "today", "yesterday",
|
|
||||||
"3 days ago") into absolute ISO date strings based on an anchor date.
|
|
||||||
|
|
||||||
Supported patterns:
|
|
||||||
- today, yesterday, tomorrow
|
|
||||||
- X days ago, in X days
|
|
||||||
- last week, next week
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text containing temporal references
|
|
||||||
anchor_date: Reference date for resolution (datetime object)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> anchor = datetime(2023, 5, 8)
|
|
||||||
>>> resolve_temporal_references("I saw him yesterday", anchor)
|
|
||||||
"I saw him 2023-05-07"
|
|
||||||
"""
|
|
||||||
# Ensure input is a string
|
|
||||||
t = str(text) if text is not None else ""
|
|
||||||
|
|
||||||
# today / yesterday / tomorrow
|
|
||||||
t = re.sub(
|
|
||||||
r"\btoday\b",
|
|
||||||
anchor_date.date().isoformat(),
|
|
||||||
t,
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
t = re.sub(
|
|
||||||
r"\byesterday\b",
|
|
||||||
(anchor_date - timedelta(days=1)).date().isoformat(),
|
|
||||||
t,
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
t = re.sub(
|
|
||||||
r"\btomorrow\b",
|
|
||||||
(anchor_date + timedelta(days=1)).date().isoformat(),
|
|
||||||
t,
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
|
|
||||||
# X days ago
|
|
||||||
def _ago_repl(m: re.Match[str]) -> str:
|
|
||||||
n = int(m.group(1))
|
|
||||||
return (anchor_date - timedelta(days=n)).date().isoformat()
|
|
||||||
|
|
||||||
# in X days
|
|
||||||
def _in_repl(m: re.Match[str]) -> str:
|
|
||||||
n = int(m.group(1))
|
|
||||||
return (anchor_date + timedelta(days=n)).date().isoformat()
|
|
||||||
|
|
||||||
t = re.sub(
|
|
||||||
r"\b(\d+)\s+days?\s+ago\b",
|
|
||||||
_ago_repl,
|
|
||||||
t,
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
t = re.sub(
|
|
||||||
r"\bin\s+(\d+)\s+days?\b",
|
|
||||||
_in_repl,
|
|
||||||
t,
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
|
|
||||||
# last week / next week (approximate as 7 days)
|
|
||||||
t = re.sub(
|
|
||||||
r"\blast\s+week\b",
|
|
||||||
(anchor_date - timedelta(days=7)).date().isoformat(),
|
|
||||||
t,
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
t = re.sub(
|
|
||||||
r"\bnext\s+week\b",
|
|
||||||
(anchor_date + timedelta(days=7)).date().isoformat(),
|
|
||||||
t,
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
)
|
|
||||||
|
|
||||||
return t
|
|
||||||
|
|
||||||
|
|
||||||
def select_and_format_information(
|
|
||||||
retrieved_info: List[str],
|
|
||||||
question: str,
|
|
||||||
max_chars: int = 8000
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Intelligently select and format most relevant retrieved information for LLM prompt.
|
|
||||||
|
|
||||||
This function scores each piece of retrieved information based on keyword matching
|
|
||||||
with the question, then selects the highest-scoring pieces up to the character limit.
|
|
||||||
|
|
||||||
Scoring criteria:
|
|
||||||
- Keyword matches (higher weight for multiple occurrences)
|
|
||||||
- Context length (moderate length preferred)
|
|
||||||
- Position (earlier contexts get bonus points)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
retrieved_info: List of retrieved information strings (chunks, statements, entities)
|
|
||||||
question: Question being answered
|
|
||||||
max_chars: Maximum total characters to include in final prompt
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string combining the most relevant information for LLM prompt.
|
|
||||||
Contexts are separated by double newlines.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
|
|
||||||
>>> question = "Where did Alice go?"
|
|
||||||
>>> select_and_format_information(contexts, question, max_chars=100)
|
|
||||||
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
|
|
||||||
"""
|
|
||||||
if not retrieved_info:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Extract question keywords (filter out stop words and short words)
|
|
||||||
question_lower = question.lower()
|
|
||||||
stop_words = {
|
|
||||||
'what', 'when', 'where', 'who', 'why', 'how',
|
|
||||||
'did', 'do', 'does', 'is', 'are', 'was', 'were',
|
|
||||||
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
|
|
||||||
}
|
|
||||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
|
||||||
question_words = {
|
|
||||||
word for word in question_words
|
|
||||||
if word not in stop_words and len(word) > 2
|
|
||||||
}
|
|
||||||
|
|
||||||
# Score each context
|
|
||||||
scored_contexts = []
|
|
||||||
for i, context in enumerate(retrieved_info):
|
|
||||||
context_lower = context.lower()
|
|
||||||
score = 0
|
|
||||||
|
|
||||||
# Keyword matching score
|
|
||||||
keyword_matches = 0
|
|
||||||
for word in question_words:
|
|
||||||
if word in context_lower:
|
|
||||||
keyword_matches += 1
|
|
||||||
# Multiple occurrences increase score
|
|
||||||
score += context_lower.count(word) * 2
|
|
||||||
|
|
||||||
# Length score (prefer moderate length)
|
|
||||||
context_len = len(context)
|
|
||||||
if 100 < context_len < 2000:
|
|
||||||
score += 5
|
|
||||||
elif context_len >= 2000:
|
|
||||||
score += 2
|
|
||||||
|
|
||||||
# Position bonus (earlier contexts often more relevant)
|
|
||||||
if i < 3:
|
|
||||||
score += 3
|
|
||||||
|
|
||||||
scored_contexts.append((score, context, keyword_matches))
|
|
||||||
|
|
||||||
# Sort by score (descending)
|
|
||||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
|
|
||||||
# Select contexts up to character limit
|
|
||||||
selected = []
|
|
||||||
total_chars = 0
|
|
||||||
|
|
||||||
for score, context, matches in scored_contexts:
|
|
||||||
if total_chars + len(context) <= max_chars:
|
|
||||||
selected.append(context)
|
|
||||||
total_chars += len(context)
|
|
||||||
else:
|
|
||||||
# Try to include high-scoring context by truncating
|
|
||||||
if score > 10 and total_chars < max_chars - 500:
|
|
||||||
remaining = max_chars - total_chars
|
|
||||||
# Find lines with keywords
|
|
||||||
lines = context.split('\n')
|
|
||||||
relevant_lines = []
|
|
||||||
current_chars = 0
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
line_lower = line.lower()
|
|
||||||
line_relevance = any(word in line_lower for word in question_words)
|
|
||||||
|
|
||||||
if line_relevance and current_chars < remaining - 100:
|
|
||||||
relevant_lines.append(line)
|
|
||||||
current_chars += len(line)
|
|
||||||
|
|
||||||
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
|
|
||||||
truncated = '\n'.join(relevant_lines)
|
|
||||||
selected.append(truncated + "\n[Content truncated...]")
|
|
||||||
total_chars += len(truncated)
|
|
||||||
break
|
|
||||||
|
|
||||||
return "\n\n".join(selected)
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_relevant_information(
|
|
||||||
question: str,
|
|
||||||
group_id: str,
|
|
||||||
search_type: str,
|
|
||||||
search_limit: int,
|
|
||||||
connector: Any,
|
|
||||||
embedder: Any
|
|
||||||
) -> List[str]:
|
|
||||||
"""
|
|
||||||
Retrieve relevant information from memory graph for a question.
|
|
||||||
|
|
||||||
This function searches the Neo4j memory graph (populated during ingestion) and
|
|
||||||
returns relevant chunks, statements, and entity information that might help
|
|
||||||
answer the question.
|
|
||||||
|
|
||||||
The function supports three search types:
|
|
||||||
- "keyword": Full-text search using Cypher queries
|
|
||||||
- "embedding": Vector similarity search using embeddings
|
|
||||||
- "hybrid": Combination of keyword and embedding search with reranking
|
|
||||||
|
|
||||||
Args:
|
|
||||||
question: Question to search for
|
|
||||||
group_id: Database group ID (identifies which conversation memory to search)
|
|
||||||
search_type: "keyword", "embedding", or "hybrid"
|
|
||||||
search_limit: Max memory pieces to retrieve
|
|
||||||
connector: Neo4j connector instance
|
|
||||||
embedder: Embedder client instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text strings (chunks, statements, entity summaries) from memory graph.
|
|
||||||
Each string represents a piece of retrieved information.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If search fails (caught and returns empty list)
|
|
||||||
"""
|
|
||||||
from app.repositories.neo4j.graph_search import (
|
|
||||||
search_graph,
|
|
||||||
search_graph_by_embedding
|
|
||||||
)
|
|
||||||
from app.core.memory.storage_services.search import run_hybrid_search
|
|
||||||
|
|
||||||
contexts_all: List[str] = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
if search_type == "embedding":
|
|
||||||
# Embedding-based search
|
|
||||||
search_results = await search_graph_by_embedding(
|
|
||||||
connector=connector,
|
|
||||||
embedder_client=embedder,
|
|
||||||
query_text=question,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=search_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
|
||||||
)
|
|
||||||
|
|
||||||
chunks = search_results.get("chunks", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
summaries = search_results.get("summaries", [])
|
|
||||||
|
|
||||||
# Build context from chunks
|
|
||||||
for c in chunks:
|
|
||||||
content = str(c.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
contexts_all.append(content)
|
|
||||||
|
|
||||||
# Add statements
|
|
||||||
for s in statements:
|
|
||||||
stmt_text = str(s.get("statement", "")).strip()
|
|
||||||
if stmt_text:
|
|
||||||
contexts_all.append(stmt_text)
|
|
||||||
|
|
||||||
# Add summaries
|
|
||||||
for sm in summaries:
|
|
||||||
summary_text = str(sm.get("summary", "")).strip()
|
|
||||||
if summary_text:
|
|
||||||
contexts_all.append(summary_text)
|
|
||||||
|
|
||||||
# Add top entities (limit to 3 to avoid noise)
|
|
||||||
if entities:
|
|
||||||
scored = [e for e in entities if e.get("score") is not None]
|
|
||||||
top_entities = (
|
|
||||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
|
||||||
if scored else entities[:3]
|
|
||||||
)
|
|
||||||
if top_entities:
|
|
||||||
summary_lines = []
|
|
||||||
for e in top_entities:
|
|
||||||
name = str(e.get("name", "")).strip()
|
|
||||||
etype = str(e.get("entity_type", "")).strip()
|
|
||||||
score = e.get("score")
|
|
||||||
if name:
|
|
||||||
meta = []
|
|
||||||
if etype:
|
|
||||||
meta.append(f"type={etype}")
|
|
||||||
if isinstance(score, (int, float)):
|
|
||||||
meta.append(f"score={score:.3f}")
|
|
||||||
summary_lines.append(
|
|
||||||
f"EntitySummary: {name}"
|
|
||||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
|
||||||
)
|
|
||||||
if summary_lines:
|
|
||||||
contexts_all.append("\n".join(summary_lines))
|
|
||||||
|
|
||||||
elif search_type == "keyword":
|
|
||||||
# Keyword-based search
|
|
||||||
search_results = await search_graph(
|
|
||||||
connector=connector,
|
|
||||||
q=question,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=search_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
dialogs = search_results.get("dialogues", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
|
|
||||||
# Build context from dialogues
|
|
||||||
for d in dialogs:
|
|
||||||
content = str(d.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
contexts_all.append(content)
|
|
||||||
|
|
||||||
# Add statements
|
|
||||||
for s in statements:
|
|
||||||
stmt_text = str(s.get("statement", "")).strip()
|
|
||||||
if stmt_text:
|
|
||||||
contexts_all.append(stmt_text)
|
|
||||||
|
|
||||||
# Add entity names
|
|
||||||
if entities:
|
|
||||||
entity_names = [
|
|
||||||
str(e.get("name", "")).strip()
|
|
||||||
for e in entities[:5]
|
|
||||||
if e.get("name")
|
|
||||||
]
|
|
||||||
if entity_names:
|
|
||||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
|
||||||
|
|
||||||
else: # hybrid
|
|
||||||
# Hybrid search with fallback to embedding
|
|
||||||
try:
|
|
||||||
search_results = await run_hybrid_search(
|
|
||||||
query_text=question,
|
|
||||||
search_type=search_type,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=search_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
|
||||||
output_path=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle flat structure (new API format)
|
|
||||||
if search_results and isinstance(search_results, dict):
|
|
||||||
chunks = search_results.get("chunks", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
summaries = search_results.get("summaries", [])
|
|
||||||
|
|
||||||
# Check if we got results
|
|
||||||
if not (chunks or statements or entities or summaries):
|
|
||||||
# Try nested structure (backward compatibility)
|
|
||||||
reranked = search_results.get("reranked_results", {})
|
|
||||||
if reranked and isinstance(reranked, dict):
|
|
||||||
chunks = reranked.get("chunks", [])
|
|
||||||
statements = reranked.get("statements", [])
|
|
||||||
entities = reranked.get("entities", [])
|
|
||||||
summaries = reranked.get("summaries", [])
|
|
||||||
else:
|
|
||||||
raise ValueError("Hybrid search returned empty results")
|
|
||||||
else:
|
|
||||||
raise ValueError("Hybrid search returned empty results")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Fallback to embedding search
|
|
||||||
search_results = await search_graph_by_embedding(
|
|
||||||
connector=connector,
|
|
||||||
embedder_client=embedder,
|
|
||||||
query_text=question,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=search_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
|
||||||
)
|
|
||||||
chunks = search_results.get("chunks", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
summaries = search_results.get("summaries", [])
|
|
||||||
|
|
||||||
# Build context (same for both hybrid and fallback)
|
|
||||||
for c in chunks:
|
|
||||||
content = str(c.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
contexts_all.append(content)
|
|
||||||
|
|
||||||
for s in statements:
|
|
||||||
stmt_text = str(s.get("statement", "")).strip()
|
|
||||||
if stmt_text:
|
|
||||||
contexts_all.append(stmt_text)
|
|
||||||
|
|
||||||
for sm in summaries:
|
|
||||||
summary_text = str(sm.get("summary", "")).strip()
|
|
||||||
if summary_text:
|
|
||||||
contexts_all.append(summary_text)
|
|
||||||
|
|
||||||
# Add top entities
|
|
||||||
if entities:
|
|
||||||
scored = [e for e in entities if e.get("score") is not None]
|
|
||||||
top_entities = (
|
|
||||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
|
||||||
if scored else entities[:3]
|
|
||||||
)
|
|
||||||
if top_entities:
|
|
||||||
summary_lines = []
|
|
||||||
for e in top_entities:
|
|
||||||
name = str(e.get("name", "")).strip()
|
|
||||||
etype = str(e.get("entity_type", "")).strip()
|
|
||||||
score = e.get("score")
|
|
||||||
if name:
|
|
||||||
meta = []
|
|
||||||
if etype:
|
|
||||||
meta.append(f"type={etype}")
|
|
||||||
if isinstance(score, (int, float)):
|
|
||||||
meta.append(f"score={score:.3f}")
|
|
||||||
summary_lines.append(
|
|
||||||
f"EntitySummary: {name}"
|
|
||||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
|
||||||
)
|
|
||||||
if summary_lines:
|
|
||||||
contexts_all.append("\n".join(summary_lines))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Return empty list on error
|
|
||||||
contexts_all = []
|
|
||||||
|
|
||||||
return contexts_all
|
|
||||||
|
|
||||||
|
|
||||||
async def ingest_conversations_if_needed(
|
|
||||||
conversations: List[str],
|
|
||||||
group_id: str,
|
|
||||||
reset: bool = False
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Wrapper for conversation ingestion using external extraction pipeline.
|
|
||||||
|
|
||||||
This function populates the Neo4j database with processed conversation data
|
|
||||||
(chunks, statements, entities) so that the retrieval system has memory to search.
|
|
||||||
|
|
||||||
The ingestion process:
|
|
||||||
1. Parses conversation text into dialogue messages
|
|
||||||
2. Chunks the dialogues into semantic units
|
|
||||||
3. Extracts statements and entities using LLM
|
|
||||||
4. Generates embeddings for all content
|
|
||||||
5. Stores everything in Neo4j graph database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversations: List of raw conversation texts from LoCoMo dataset
|
|
||||||
Example: ["User: I went to Paris. AI: When was that?", ...]
|
|
||||||
group_id: Target group ID for database storage
|
|
||||||
reset: Whether to clear existing data first (not implemented in wrapper)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if successful, False otherwise
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The external function uses "contexts" to mean "conversation texts".
|
|
||||||
This runs the full extraction pipeline: chunking → entity extraction →
|
|
||||||
statement extraction → embedding → Neo4j storage.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
success = await ingest_contexts_via_full_pipeline(
|
|
||||||
contexts=conversations,
|
|
||||||
group_id=group_id,
|
|
||||||
save_chunk_output=True
|
|
||||||
)
|
|
||||||
return success
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[Ingestion] Failed to ingest conversations: {e}")
|
|
||||||
return False
|
|
||||||
@@ -1,878 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import statistics
|
|
||||||
import time
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
try:
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
except Exception:
|
|
||||||
def load_dotenv():
|
|
||||||
return None
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from app.core.memory.evaluation.common.metrics import (
|
|
||||||
avg_context_tokens,
|
|
||||||
bleu1,
|
|
||||||
jaccard,
|
|
||||||
latency_stats,
|
|
||||||
)
|
|
||||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
|
||||||
from app.core.memory.evaluation.extraction_utils import (
|
|
||||||
ingest_contexts_via_full_pipeline,
|
|
||||||
)
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|
||||||
from app.core.memory.storage_services.search import run_hybrid_search
|
|
||||||
from app.core.memory.utils.config.definitions import (
|
|
||||||
PROJECT_ROOT,
|
|
||||||
SELECTED_EMBEDDING_ID,
|
|
||||||
SELECTED_GROUP_ID,
|
|
||||||
SELECTED_LLM_ID,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
|
|
||||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
|
||||||
def _loc_normalize(text: str) -> str:
|
|
||||||
import re
|
|
||||||
# 确保输入是字符串
|
|
||||||
text = str(text) if text is not None else ""
|
|
||||||
text = text.lower()
|
|
||||||
text = re.sub(r"[\,]", " ", text) # 去掉逗号
|
|
||||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
|
||||||
text = re.sub(r"[^\w\s]", " ", text)
|
|
||||||
text = " ".join(text.split())
|
|
||||||
return text
|
|
||||||
|
|
||||||
# 追加:相对时间归一化为绝对日期(有限支持:today/yesterday/tomorrow/X days ago/in X days/last week/next week)
|
|
||||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
|
||||||
import re
|
|
||||||
# 确保输入是字符串
|
|
||||||
t = str(text) if text is not None else ""
|
|
||||||
# today / yesterday / tomorrow
|
|
||||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
# X days ago / in X days
|
|
||||||
def _ago_repl(m: re.Match[str]) -> str:
|
|
||||||
n = int(m.group(1))
|
|
||||||
return (anchor - timedelta(days=n)).date().isoformat()
|
|
||||||
def _in_repl(m: re.Match[str]) -> str:
|
|
||||||
n = int(m.group(1))
|
|
||||||
return (anchor + timedelta(days=n)).date().isoformat()
|
|
||||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
|
||||||
# last week / next week(以7天近似)
|
|
||||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
|
||||||
return t
|
|
||||||
|
|
||||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
|
||||||
# 单答案 F1:按词集合计算(近似原始实现,去除词干依赖)
|
|
||||||
# 确保输入是字符串
|
|
||||||
pred_str = str(prediction) if prediction is not None else ""
|
|
||||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
|
||||||
|
|
||||||
p_tokens = _loc_normalize(pred_str).split()
|
|
||||||
g_tokens = _loc_normalize(truth_str).split()
|
|
||||||
if not p_tokens or not g_tokens:
|
|
||||||
return 0.0
|
|
||||||
p = set(p_tokens)
|
|
||||||
g = set(g_tokens)
|
|
||||||
tp = len(p & g)
|
|
||||||
precision = tp / len(p) if p else 0.0
|
|
||||||
recall = tp / len(g) if g else 0.0
|
|
||||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
|
||||||
|
|
||||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
|
||||||
# 多答案 F1:prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
|
|
||||||
# 确保输入是字符串
|
|
||||||
pred_str = str(prediction) if prediction is not None else ""
|
|
||||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
|
||||||
|
|
||||||
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
|
|
||||||
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
|
|
||||||
if not predictions or not ground_truths:
|
|
||||||
return 0.0
|
|
||||||
def _f1(a: str, b: str) -> float:
|
|
||||||
return loc_f1_score(a, b)
|
|
||||||
vals = []
|
|
||||||
for gt in ground_truths:
|
|
||||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
|
||||||
return sum(vals) / len(vals)
|
|
||||||
|
|
||||||
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
|
|
||||||
CATEGORY_MAP_NUM_TO_NAME = {
|
|
||||||
4: "Single-Hop",
|
|
||||||
1: "Multi-Hop",
|
|
||||||
3: "Open Domain",
|
|
||||||
2: "Temporal",
|
|
||||||
}
|
|
||||||
|
|
||||||
_TYPE_ALIASES = {
|
|
||||||
"single-hop": "Single-Hop",
|
|
||||||
"singlehop": "Single-Hop",
|
|
||||||
"single hop": "Single-Hop",
|
|
||||||
"multi-hop": "Multi-Hop",
|
|
||||||
"multihop": "Multi-Hop",
|
|
||||||
"multi hop": "Multi-Hop",
|
|
||||||
"open domain": "Open Domain",
|
|
||||||
"opendomain": "Open Domain",
|
|
||||||
"temporal": "Temporal",
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_category_label(item: Dict[str, Any]) -> str:
|
|
||||||
# 1) 直接用字符串 cat
|
|
||||||
cat = item.get("cat")
|
|
||||||
if isinstance(cat, str) and cat.strip():
|
|
||||||
name = cat.strip()
|
|
||||||
lower = name.lower()
|
|
||||||
return _TYPE_ALIASES.get(lower, name)
|
|
||||||
# 2) 数字 category 转名称
|
|
||||||
cat_num = item.get("category")
|
|
||||||
if isinstance(cat_num, int):
|
|
||||||
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
|
|
||||||
# 3) 备用 type 字段
|
|
||||||
t = item.get("type")
|
|
||||||
if isinstance(t, str) and t.strip():
|
|
||||||
lower = t.strip().lower()
|
|
||||||
return _TYPE_ALIASES.get(lower, t.strip())
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
|
|
||||||
"""基于问题关键词智能选择上下文"""
|
|
||||||
if not contexts:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 提取问题关键词(只保留有意义的词)
|
|
||||||
question_lower = question.lower()
|
|
||||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
|
||||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
|
||||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
|
||||||
|
|
||||||
print(f"🔍 问题关键词: {question_words}")
|
|
||||||
|
|
||||||
# 给每个上下文打分
|
|
||||||
scored_contexts = []
|
|
||||||
for i, context in enumerate(contexts):
|
|
||||||
context_lower = context.lower()
|
|
||||||
score = 0
|
|
||||||
|
|
||||||
# 关键词匹配得分
|
|
||||||
keyword_matches = 0
|
|
||||||
for word in question_words:
|
|
||||||
if word in context_lower:
|
|
||||||
keyword_matches += 1
|
|
||||||
# 关键词出现次数越多,得分越高
|
|
||||||
score += context_lower.count(word) * 2
|
|
||||||
|
|
||||||
# 上下文长度得分(适中的长度更好)
|
|
||||||
context_len = len(context)
|
|
||||||
if 100 < context_len < 2000: # 理想长度范围
|
|
||||||
score += 5
|
|
||||||
elif context_len >= 2000: # 太长可能包含无关信息
|
|
||||||
score += 2
|
|
||||||
|
|
||||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
|
||||||
if i < 3:
|
|
||||||
score += 3
|
|
||||||
|
|
||||||
scored_contexts.append((score, context, keyword_matches))
|
|
||||||
|
|
||||||
# 按得分排序
|
|
||||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
|
|
||||||
# 选择高得分的上下文,直到达到字符限制
|
|
||||||
selected = []
|
|
||||||
total_chars = 0
|
|
||||||
selected_count = 0
|
|
||||||
|
|
||||||
print("📊 上下文相关性分析:")
|
|
||||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
|
||||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
|
||||||
|
|
||||||
for score, context, matches in scored_contexts:
|
|
||||||
if total_chars + len(context) <= max_chars:
|
|
||||||
selected.append(context)
|
|
||||||
total_chars += len(context)
|
|
||||||
selected_count += 1
|
|
||||||
else:
|
|
||||||
# 如果这个上下文得分很高但放不下,尝试截取
|
|
||||||
if score > 10 and total_chars < max_chars - 500:
|
|
||||||
remaining = max_chars - total_chars
|
|
||||||
# 找到包含关键词的部分
|
|
||||||
lines = context.split('\n')
|
|
||||||
relevant_lines = []
|
|
||||||
current_chars = 0
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
line_lower = line.lower()
|
|
||||||
line_relevance = any(word in line_lower for word in question_words)
|
|
||||||
|
|
||||||
if line_relevance and current_chars < remaining - 100:
|
|
||||||
relevant_lines.append(line)
|
|
||||||
current_chars += len(line)
|
|
||||||
|
|
||||||
if relevant_lines:
|
|
||||||
truncated = '\n'.join(relevant_lines)
|
|
||||||
if len(truncated) > 100: # 确保有足够内容
|
|
||||||
selected.append(truncated + "\n[相关内容截断...]")
|
|
||||||
total_chars += len(truncated)
|
|
||||||
selected_count += 1
|
|
||||||
break # 不再尝试添加更多上下文
|
|
||||||
|
|
||||||
result = "\n\n".join(selected)
|
|
||||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_search_params_by_category(category: str):
|
|
||||||
"""根据问题类别调整检索参数"""
|
|
||||||
params_map = {
|
|
||||||
"Multi-Hop": {"limit": 20, "max_chars": 15000},
|
|
||||||
"Temporal": {"limit": 16, "max_chars": 10000},
|
|
||||||
"Open Domain": {"limit": 24, "max_chars": 18000},
|
|
||||||
"Single-Hop": {"limit": 12, "max_chars": 8000},
|
|
||||||
}
|
|
||||||
return params_map.get(category, {"limit": 16, "max_chars": 12000})
|
|
||||||
|
|
||||||
|
|
||||||
async def run_locomo_eval(
|
|
||||||
sample_size: int = 1,
|
|
||||||
group_id: str | None = None,
|
|
||||||
search_limit: int = 8,
|
|
||||||
context_char_budget: int = 4000, # 保持默认值不变
|
|
||||||
llm_temperature: float = 0.0,
|
|
||||||
llm_max_tokens: int = 32,
|
|
||||||
search_type: str = "hybrid", # 保持默认值不变
|
|
||||||
output_path: str | None = None,
|
|
||||||
skip_ingest_if_exists: bool = True,
|
|
||||||
llm_timeout: float = 10.0,
|
|
||||||
llm_max_retries: int = 1
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
|
|
||||||
# 函数内部使用三路检索逻辑,但保持参数签名不变
|
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
|
||||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
|
||||||
if not os.path.exists(data_path):
|
|
||||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
|
||||||
with open(data_path, "r", encoding="utf-8") as f:
|
|
||||||
raw = json.load(f)
|
|
||||||
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
|
|
||||||
qa_items: List[Dict[str, Any]] = []
|
|
||||||
if isinstance(raw, list):
|
|
||||||
for entry in raw:
|
|
||||||
qa_items.extend(entry.get("qa", []))
|
|
||||||
else:
|
|
||||||
qa_items.extend(raw.get("qa", []))
|
|
||||||
items: List[Dict[str, Any]] = qa_items[:sample_size]
|
|
||||||
|
|
||||||
# === 保持原来的数据摄入逻辑 ===
|
|
||||||
entries = raw if isinstance(raw, list) else [raw]
|
|
||||||
|
|
||||||
# 只摄入前1条对话(保持原样)
|
|
||||||
max_dialogues_to_ingest = 1
|
|
||||||
contents: List[str] = []
|
|
||||||
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条")
|
|
||||||
|
|
||||||
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
|
|
||||||
if not isinstance(entry, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
conv = entry.get("conversation", {})
|
|
||||||
sample_id = entry.get("sample_id", f"unknown_{i}")
|
|
||||||
|
|
||||||
print(f"🔍 处理对话 {i+1}: {sample_id}")
|
|
||||||
|
|
||||||
lines: List[str] = []
|
|
||||||
if isinstance(conv, dict):
|
|
||||||
# 收集所有 session_* 的消息
|
|
||||||
session_count = 0
|
|
||||||
for key, val in conv.items():
|
|
||||||
if isinstance(val, list) and key.startswith("session_"):
|
|
||||||
session_count += 1
|
|
||||||
for msg in val:
|
|
||||||
role = msg.get("speaker") or "用户"
|
|
||||||
text = msg.get("text") or ""
|
|
||||||
text = str(text).strip()
|
|
||||||
if not text:
|
|
||||||
continue
|
|
||||||
lines.append(f"{role}: {text}")
|
|
||||||
|
|
||||||
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
|
|
||||||
|
|
||||||
if not lines:
|
|
||||||
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
|
|
||||||
continue
|
|
||||||
|
|
||||||
contents.append("\n".join(lines))
|
|
||||||
|
|
||||||
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
|
|
||||||
|
|
||||||
# 选择要评测的QA对(从所有对话中选取)
|
|
||||||
indexed_items: List[tuple[int, Dict[str, Any]]] = []
|
|
||||||
if isinstance(raw, list):
|
|
||||||
for e_idx, entry in enumerate(raw):
|
|
||||||
for qa in entry.get("qa", []):
|
|
||||||
indexed_items.append((e_idx, qa))
|
|
||||||
else:
|
|
||||||
for qa in raw.get("qa", []):
|
|
||||||
indexed_items.append((0, qa))
|
|
||||||
|
|
||||||
# 这里使用sample_size来限制评测的QA数量
|
|
||||||
selected = indexed_items[:sample_size]
|
|
||||||
items: List[Dict[str, Any]] = [qa for _, qa in selected]
|
|
||||||
|
|
||||||
print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话")
|
|
||||||
# === 修改结束 ===
|
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
|
|
||||||
# 关键修复:强制重新摄入纯净的对话数据
|
|
||||||
print("🔄 强制重新摄入纯净的对话数据...")
|
|
||||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
|
||||||
|
|
||||||
# 使用异步LLM客户端
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
|
||||||
# 初始化embedder用于直接调用
|
|
||||||
with get_db_context() as db:
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
|
||||||
embedder = OpenAIEmbedderClient(
|
|
||||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
|
||||||
)
|
|
||||||
|
|
||||||
# connector initialized above
|
|
||||||
latencies_llm: List[float] = []
|
|
||||||
latencies_search: List[float] = []
|
|
||||||
# 上下文诊断收集
|
|
||||||
per_query_context_counts: List[int] = []
|
|
||||||
per_query_context_avg_tokens: List[float] = []
|
|
||||||
per_query_context_chars: List[int] = []
|
|
||||||
per_query_context_tokens_total: List[int] = []
|
|
||||||
# 详细样本调试信息
|
|
||||||
samples: List[Dict[str, Any]] = []
|
|
||||||
# 通用指标
|
|
||||||
f1s: List[float] = []
|
|
||||||
b1s: List[float] = []
|
|
||||||
jss: List[float] = []
|
|
||||||
# 参考 LoCoMo 评测的类别专用 F1(multi-hop 使用多答案 F1)
|
|
||||||
loc_f1s: List[float] = []
|
|
||||||
# Per-category aggregation
|
|
||||||
cat_counts: Dict[str, int] = {}
|
|
||||||
cat_f1s: Dict[str, List[float]] = {}
|
|
||||||
cat_b1s: Dict[str, List[float]] = {}
|
|
||||||
cat_jss: Dict[str, List[float]] = {}
|
|
||||||
cat_loc_f1s: Dict[str, List[float]] = {}
|
|
||||||
try:
|
|
||||||
for item in items:
|
|
||||||
q = item.get("question", "")
|
|
||||||
ref = item.get("answer", "")
|
|
||||||
# 确保答案是字符串
|
|
||||||
ref_str = str(ref) if ref is not None else ""
|
|
||||||
cat = get_category_label(item)
|
|
||||||
|
|
||||||
print(f"\n=== 处理问题: {q} ===")
|
|
||||||
|
|
||||||
# 根据类别调整检索参数
|
|
||||||
search_params = get_search_params_by_category(cat)
|
|
||||||
adjusted_limit = search_params["limit"]
|
|
||||||
max_chars = search_params["max_chars"]
|
|
||||||
|
|
||||||
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
|
|
||||||
|
|
||||||
# 改进的检索逻辑:使用三路检索(statements, dialogues, entities)
|
|
||||||
t0 = time.time()
|
|
||||||
contexts_all: List[str] = []
|
|
||||||
search_results = None # 保存完整的检索结果
|
|
||||||
|
|
||||||
try:
|
|
||||||
if search_type == "embedding":
|
|
||||||
# 直接调用嵌入检索,包含三路数据
|
|
||||||
search_results = await search_graph_by_embedding(
|
|
||||||
connector=connector,
|
|
||||||
embedder_client=embedder,
|
|
||||||
query_text=q,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=adjusted_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
|
|
||||||
)
|
|
||||||
chunks = search_results.get("chunks", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
summaries = search_results.get("summaries", [])
|
|
||||||
|
|
||||||
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
|
||||||
|
|
||||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
|
||||||
for c in chunks:
|
|
||||||
content = str(c.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
contexts_all.append(content)
|
|
||||||
|
|
||||||
for s in statements:
|
|
||||||
stmt_text = str(s.get("statement", "")).strip()
|
|
||||||
if stmt_text:
|
|
||||||
contexts_all.append(stmt_text)
|
|
||||||
|
|
||||||
for sm in summaries:
|
|
||||||
summary_text = str(sm.get("summary", "")).strip()
|
|
||||||
if summary_text:
|
|
||||||
contexts_all.append(summary_text)
|
|
||||||
|
|
||||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
|
||||||
scored = [e for e in entities if e.get("score") is not None]
|
|
||||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
|
||||||
if top_entities:
|
|
||||||
summary_lines = []
|
|
||||||
for e in top_entities:
|
|
||||||
name = str(e.get("name", "")).strip()
|
|
||||||
etype = str(e.get("entity_type", "")).strip()
|
|
||||||
score = e.get("score")
|
|
||||||
if name:
|
|
||||||
meta = []
|
|
||||||
if etype:
|
|
||||||
meta.append(f"type={etype}")
|
|
||||||
if isinstance(score, (int, float)):
|
|
||||||
meta.append(f"score={score:.3f}")
|
|
||||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
|
||||||
if summary_lines:
|
|
||||||
contexts_all.append("\n".join(summary_lines))
|
|
||||||
|
|
||||||
elif search_type == "keyword":
|
|
||||||
# 直接调用关键词检索
|
|
||||||
search_results = await search_graph(
|
|
||||||
connector=connector,
|
|
||||||
q=q,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=adjusted_limit
|
|
||||||
)
|
|
||||||
dialogs = search_results.get("dialogues", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
|
|
||||||
|
|
||||||
# 构建上下文
|
|
||||||
for d in dialogs:
|
|
||||||
content = str(d.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
contexts_all.append(content)
|
|
||||||
for s in statements:
|
|
||||||
stmt_text = str(s.get("statement", "")).strip()
|
|
||||||
if stmt_text:
|
|
||||||
contexts_all.append(stmt_text)
|
|
||||||
# 实体处理(关键词检索的实体可能没有分数)
|
|
||||||
if entities:
|
|
||||||
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
|
|
||||||
if entity_names:
|
|
||||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
|
||||||
|
|
||||||
else: # hybrid
|
|
||||||
# 🎯 关键修复:混合检索使用更严格的回退机制
|
|
||||||
print("🔀 使用混合检索(带回退机制)...")
|
|
||||||
try:
|
|
||||||
search_results = await run_hybrid_search(
|
|
||||||
query_text=q,
|
|
||||||
search_type=search_type,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=adjusted_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
|
||||||
output_path=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 🎯 关键修复:正确处理混合检索的扁平结构
|
|
||||||
# 新的API返回扁平结构,直接从顶层获取结果
|
|
||||||
if search_results and isinstance(search_results, dict):
|
|
||||||
# 新API返回扁平结构:直接从顶层获取
|
|
||||||
chunks = search_results.get("chunks", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
summaries = search_results.get("summaries", [])
|
|
||||||
|
|
||||||
# 检查是否有有效结果
|
|
||||||
if chunks or statements or entities or summaries:
|
|
||||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
|
|
||||||
else:
|
|
||||||
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
|
|
||||||
reranked = search_results.get("reranked_results", {})
|
|
||||||
if reranked and isinstance(reranked, dict):
|
|
||||||
chunks = reranked.get("chunks", [])
|
|
||||||
statements = reranked.get("statements", [])
|
|
||||||
entities = reranked.get("entities", [])
|
|
||||||
summaries = reranked.get("summaries", [])
|
|
||||||
print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述")
|
|
||||||
else:
|
|
||||||
raise ValueError("混合检索返回空结果")
|
|
||||||
else:
|
|
||||||
raise ValueError("混合检索返回空结果")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
|
|
||||||
search_results = await search_graph_by_embedding(
|
|
||||||
connector=connector,
|
|
||||||
embedder_client=embedder,
|
|
||||||
query_text=q,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=adjusted_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"],
|
|
||||||
)
|
|
||||||
chunks = search_results.get("chunks", [])
|
|
||||||
statements = search_results.get("statements", [])
|
|
||||||
entities = search_results.get("entities", [])
|
|
||||||
summaries = search_results.get("summaries", [])
|
|
||||||
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
|
|
||||||
|
|
||||||
# 🎯 统一处理:构建上下文(所有检索类型共用)
|
|
||||||
for c in chunks:
|
|
||||||
content = str(c.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
contexts_all.append(content)
|
|
||||||
|
|
||||||
for s in statements:
|
|
||||||
stmt_text = str(s.get("statement", "")).strip()
|
|
||||||
if stmt_text:
|
|
||||||
contexts_all.append(stmt_text)
|
|
||||||
|
|
||||||
for sm in summaries:
|
|
||||||
summary_text = str(sm.get("summary", "")).strip()
|
|
||||||
if summary_text:
|
|
||||||
contexts_all.append(summary_text)
|
|
||||||
|
|
||||||
# 实体摘要:最多加入前3个高分实体
|
|
||||||
if entities:
|
|
||||||
scored = [e for e in entities if e.get("score") is not None]
|
|
||||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
|
||||||
if top_entities:
|
|
||||||
summary_lines = []
|
|
||||||
for e in top_entities:
|
|
||||||
name = str(e.get("name", "")).strip()
|
|
||||||
etype = str(e.get("entity_type", "")).strip()
|
|
||||||
score = e.get("score")
|
|
||||||
if name:
|
|
||||||
meta = []
|
|
||||||
if etype:
|
|
||||||
meta.append(f"type={etype}")
|
|
||||||
if isinstance(score, (int, float)):
|
|
||||||
meta.append(f"score={score:.3f}")
|
|
||||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
|
||||||
if summary_lines:
|
|
||||||
contexts_all.append("\n".join(summary_lines))
|
|
||||||
|
|
||||||
# 关键修复:过滤掉包含当前问题答案的上下文
|
|
||||||
filtered_contexts = []
|
|
||||||
for context in contexts_all:
|
|
||||||
content = str(context)
|
|
||||||
# 排除包含当前问题标准答案的上下文
|
|
||||||
if ref_str and ref_str.strip() and ref_str.strip() in content:
|
|
||||||
print("🚫 过滤掉包含标准答案的上下文")
|
|
||||||
continue
|
|
||||||
filtered_contexts.append(context)
|
|
||||||
|
|
||||||
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
|
|
||||||
contexts_all = filtered_contexts
|
|
||||||
|
|
||||||
# 输出完整的检索结果信息
|
|
||||||
print("🔍 检索结果详情:")
|
|
||||||
if search_results:
|
|
||||||
output_data = {
|
|
||||||
"statements": [
|
|
||||||
{
|
|
||||||
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
|
|
||||||
"score": s.get("score", 0.0)
|
|
||||||
}
|
|
||||||
for s in (statements[:2] if 'statements' in locals() else [])
|
|
||||||
],
|
|
||||||
"dialogues": [
|
|
||||||
{
|
|
||||||
"uuid": d.get("uuid", ""),
|
|
||||||
"group_id": d.get("group_id", ""),
|
|
||||||
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
|
|
||||||
"score": d.get("score", 0.0)
|
|
||||||
}
|
|
||||||
for d in (dialogs[:2] if 'dialogs' in locals() else [])
|
|
||||||
],
|
|
||||||
"entities": [
|
|
||||||
{
|
|
||||||
"name": e.get("name", ""),
|
|
||||||
"entity_type": e.get("entity_type", ""),
|
|
||||||
"score": e.get("score", 0.0)
|
|
||||||
}
|
|
||||||
for e in (entities[:2] if 'entities' in locals() else [])
|
|
||||||
]
|
|
||||||
}
|
|
||||||
print(json.dumps(output_data, ensure_ascii=False, indent=2))
|
|
||||||
else:
|
|
||||||
print(" 无检索结果")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ {search_type}检索失败: {e}")
|
|
||||||
contexts_all = []
|
|
||||||
search_results = None
|
|
||||||
|
|
||||||
t1 = time.time()
|
|
||||||
latencies_search.append((t1 - t0) * 1000)
|
|
||||||
|
|
||||||
# 使用智能上下文选择
|
|
||||||
context_text = ""
|
|
||||||
if contexts_all:
|
|
||||||
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
|
|
||||||
|
|
||||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
|
||||||
if len(context_text) > max_chars:
|
|
||||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
|
||||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
|
||||||
|
|
||||||
# 时间解析
|
|
||||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
|
||||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
|
||||||
|
|
||||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
|
||||||
|
|
||||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
|
||||||
|
|
||||||
# 显示不同上下文的预览
|
|
||||||
print("🔍 上下文预览:")
|
|
||||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
|
||||||
preview = context[:150].replace('\n', ' ')
|
|
||||||
print(f" 上下文{j+1}: {preview}...")
|
|
||||||
|
|
||||||
else:
|
|
||||||
print("❌ 没有检索到有效上下文")
|
|
||||||
context_text = "No relevant context found."
|
|
||||||
|
|
||||||
# 记录上下文诊断信息
|
|
||||||
per_query_context_counts.append(len(contexts_all))
|
|
||||||
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
|
|
||||||
per_query_context_chars.append(len(context_text))
|
|
||||||
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
|
|
||||||
|
|
||||||
# LLM 提示词
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": (
|
|
||||||
"You are a precise QA assistant. Answer following these rules:\n"
|
|
||||||
"1) Extract the EXACT information mentioned in the context\n"
|
|
||||||
"2) For time questions: calculate actual dates from relative times\n"
|
|
||||||
"3) Return ONLY the answer text in simplest form\n"
|
|
||||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
|
||||||
"5) If no clear answer found, respond with 'Unknown'"
|
|
||||||
)},
|
|
||||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
|
||||||
]
|
|
||||||
|
|
||||||
t2 = time.time()
|
|
||||||
# 使用异步调用
|
|
||||||
resp = await llm_client.chat(messages=messages)
|
|
||||||
t3 = time.time()
|
|
||||||
latencies_llm.append((t3 - t2) * 1000)
|
|
||||||
|
|
||||||
# 兼容不同的响应格式
|
|
||||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
|
||||||
|
|
||||||
# 计算指标(确保使用字符串)
|
|
||||||
f1_val = common_f1(str(pred), ref_str)
|
|
||||||
b1_val = bleu1(str(pred), ref_str)
|
|
||||||
j_val = jaccard(str(pred), ref_str)
|
|
||||||
|
|
||||||
f1s.append(f1_val)
|
|
||||||
b1s.append(b1_val)
|
|
||||||
jss.append(j_val)
|
|
||||||
|
|
||||||
# Accumulate by category
|
|
||||||
cat_counts[cat] = cat_counts.get(cat, 0) + 1
|
|
||||||
cat_f1s.setdefault(cat, []).append(f1_val)
|
|
||||||
cat_b1s.setdefault(cat, []).append(b1_val)
|
|
||||||
cat_jss.setdefault(cat, []).append(j_val)
|
|
||||||
|
|
||||||
# LoCoMo 专用 F1:multi-hop(1) 使用多答案 F1,其它(2/3/4)使用单答案 F1
|
|
||||||
if item.get("category") in [2, 3, 4]:
|
|
||||||
loc_val = loc_f1_score(str(pred), ref_str)
|
|
||||||
elif item.get("category") in [1]:
|
|
||||||
loc_val = loc_multi_f1(str(pred), ref_str)
|
|
||||||
else:
|
|
||||||
loc_val = loc_f1_score(str(pred), ref_str)
|
|
||||||
loc_f1s.append(loc_val)
|
|
||||||
cat_loc_f1s.setdefault(cat, []).append(loc_val)
|
|
||||||
|
|
||||||
# 保存完整的检索结果信息
|
|
||||||
samples.append({
|
|
||||||
"question": q,
|
|
||||||
"answer": ref_str,
|
|
||||||
"category": cat,
|
|
||||||
"prediction": pred,
|
|
||||||
"metrics": {
|
|
||||||
"f1": f1_val,
|
|
||||||
"b1": b1_val,
|
|
||||||
"j": j_val,
|
|
||||||
"loc_f1": loc_val
|
|
||||||
},
|
|
||||||
"retrieval": {
|
|
||||||
"retrieved_documents": len(contexts_all),
|
|
||||||
"context_length": len(context_text),
|
|
||||||
"search_limit": adjusted_limit,
|
|
||||||
"max_chars": max_chars
|
|
||||||
},
|
|
||||||
"timing": {
|
|
||||||
"search_ms": (t1 - t0) * 1000,
|
|
||||||
"llm_ms": (t3 - t2) * 1000
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"🤖 LLM 回答: {pred}")
|
|
||||||
print(f"✅ 正确答案: {ref_str}")
|
|
||||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
|
|
||||||
|
|
||||||
# Compute per-category averages and dispersion (std, iqr)
|
|
||||||
def _percentile(sorted_vals: List[float], p: float) -> float:
|
|
||||||
if not sorted_vals:
|
|
||||||
return 0.0
|
|
||||||
if len(sorted_vals) == 1:
|
|
||||||
return sorted_vals[0]
|
|
||||||
k = (len(sorted_vals) - 1) * p
|
|
||||||
f = int(k)
|
|
||||||
c = f + 1 if f + 1 < len(sorted_vals) else f
|
|
||||||
if f == c:
|
|
||||||
return sorted_vals[f]
|
|
||||||
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
|
|
||||||
|
|
||||||
by_category: Dict[str, Dict[str, float | int]] = {}
|
|
||||||
for c in cat_counts:
|
|
||||||
f_list = cat_f1s.get(c, [])
|
|
||||||
b_list = cat_b1s.get(c, [])
|
|
||||||
j_list = cat_jss.get(c, [])
|
|
||||||
lf_list = cat_loc_f1s.get(c, [])
|
|
||||||
j_sorted = sorted(j_list)
|
|
||||||
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
|
|
||||||
j_q75 = _percentile(j_sorted, 0.75)
|
|
||||||
j_q25 = _percentile(j_sorted, 0.25)
|
|
||||||
by_category[c] = {
|
|
||||||
"count": cat_counts[c],
|
|
||||||
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
|
|
||||||
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
|
|
||||||
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
|
|
||||||
"j_std": j_std,
|
|
||||||
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
|
|
||||||
# 参考 LoCoMo 评测的类别专用 F1
|
|
||||||
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 累加命中(cum accuracy by category):与 evaluation_stats.py 输出形式相仿
|
|
||||||
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"dataset": "locomo",
|
|
||||||
"items": len(items),
|
|
||||||
"metrics": {
|
|
||||||
"f1": sum(f1s) / max(len(f1s), 1),
|
|
||||||
"b1": sum(b1s) / max(len(b1s), 1),
|
|
||||||
"j": sum(jss) / max(len(jss), 1),
|
|
||||||
# LoCoMo 类别专用 F1 的总体
|
|
||||||
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
|
|
||||||
},
|
|
||||||
"by_category": by_category,
|
|
||||||
"category_counts": cat_counts,
|
|
||||||
"cum_accuracy_by_category": cum_accuracy_by_category,
|
|
||||||
"context": {
|
|
||||||
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
|
|
||||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
|
||||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
|
||||||
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
|
|
||||||
},
|
|
||||||
"latency": {
|
|
||||||
"search": latency_stats(latencies_search),
|
|
||||||
"llm": latency_stats(latencies_llm),
|
|
||||||
},
|
|
||||||
"samples": samples,
|
|
||||||
"params": {
|
|
||||||
"group_id": group_id,
|
|
||||||
"search_limit": search_limit,
|
|
||||||
"context_char_budget": context_char_budget,
|
|
||||||
"search_type": search_type,
|
|
||||||
"llm_id": SELECTED_LLM_ID,
|
|
||||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
|
|
||||||
"skip_ingest_if_exists": skip_ingest_if_exists,
|
|
||||||
"llm_timeout": llm_timeout,
|
|
||||||
"llm_max_retries": llm_max_retries,
|
|
||||||
"llm_temperature": llm_temperature,
|
|
||||||
"llm_max_tokens": llm_max_tokens
|
|
||||||
},
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
if output_path:
|
|
||||||
try:
|
|
||||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
|
||||||
print(f"✅ 结果已保存到: {output_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 保存结果失败: {e}")
|
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
|
|
||||||
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
|
|
||||||
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
|
|
||||||
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
|
|
||||||
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
|
|
||||||
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
|
|
||||||
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
|
|
||||||
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
|
|
||||||
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
|
|
||||||
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
|
|
||||||
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
|
|
||||||
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
result = asyncio.run(run_locomo_eval(
|
|
||||||
sample_size=args.sample_size,
|
|
||||||
group_id=args.group_id,
|
|
||||||
search_limit=args.search_limit,
|
|
||||||
context_char_budget=args.context_char_budget,
|
|
||||||
llm_temperature=args.llm_temperature,
|
|
||||||
llm_max_tokens=args.llm_max_tokens,
|
|
||||||
search_type=args.search_type,
|
|
||||||
output_path=args.output_path,
|
|
||||||
skip_ingest_if_exists=args.skip_ingest_if_exists,
|
|
||||||
llm_timeout=args.llm_timeout,
|
|
||||||
llm_max_retries=args.llm_max_retries
|
|
||||||
))
|
|
||||||
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("📊 最终评测结果:")
|
|
||||||
print(f" 样本数量: {result['items']}")
|
|
||||||
print(f" F1: {result['metrics']['f1']:.3f}")
|
|
||||||
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
|
|
||||||
print(f" Jaccard: {result['metrics']['j']:.3f}")
|
|
||||||
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
|
|
||||||
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
|
|
||||||
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
|
|
||||||
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
|
|
||||||
|
|
||||||
if result['by_category']:
|
|
||||||
print("\n📈 按类别细分:")
|
|
||||||
for cat, metrics in result['by_category'].items():
|
|
||||||
print(f" {cat}:")
|
|
||||||
print(f" 样本数: {metrics['count']}")
|
|
||||||
print(f" F1: {metrics['f1']:.3f}")
|
|
||||||
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
|
|
||||||
print(f" Jaccard: {metrics['j']:.3f} (±{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,324 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
|
||||||
|
|
||||||
try:
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
except Exception:
|
|
||||||
def load_dotenv():
|
|
||||||
return None
|
|
||||||
|
|
||||||
from app.core.memory.evaluation.common.metrics import (
|
|
||||||
avg_context_tokens,
|
|
||||||
exact_match,
|
|
||||||
latency_stats,
|
|
||||||
)
|
|
||||||
from app.core.memory.evaluation.extraction_utils import (
|
|
||||||
ingest_contexts_via_full_pipeline,
|
|
||||||
)
|
|
||||||
from app.core.memory.storage_services.search import run_hybrid_search
|
|
||||||
from app.core.memory.utils.config.definitions import (
|
|
||||||
PROJECT_ROOT,
|
|
||||||
SELECTED_EMBEDDING_ID,
|
|
||||||
SELECTED_GROUP_ID,
|
|
||||||
SELECTED_LLM_ID,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
|
|
||||||
|
|
||||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
|
||||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
|
|
||||||
if not contexts:
|
|
||||||
return ""
|
|
||||||
import re
|
|
||||||
# 提取问题关键词(移除停用词)
|
|
||||||
question_lower = (question or "").lower()
|
|
||||||
stop_words = {
|
|
||||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
|
||||||
'the','a','an','and','or','but'
|
|
||||||
}
|
|
||||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
|
||||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
|
||||||
|
|
||||||
# 评分
|
|
||||||
scored = []
|
|
||||||
for i, ctx in enumerate(contexts):
|
|
||||||
ctx_lower = (ctx or "").lower()
|
|
||||||
score = 0
|
|
||||||
matches = 0
|
|
||||||
for w in question_words:
|
|
||||||
if w in ctx_lower:
|
|
||||||
matches += 1
|
|
||||||
score += ctx_lower.count(w) * 2
|
|
||||||
length = len(ctx)
|
|
||||||
if 100 < length < 2000:
|
|
||||||
score += 5
|
|
||||||
elif length >= 2000:
|
|
||||||
score += 2
|
|
||||||
if i < 3:
|
|
||||||
score += 3
|
|
||||||
scored.append((score, ctx, matches))
|
|
||||||
|
|
||||||
scored.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
|
|
||||||
# 选择直到达到字符限制,必要时截断包含关键词的段落
|
|
||||||
selected: List[str] = []
|
|
||||||
total = 0
|
|
||||||
for score, ctx, _ in scored:
|
|
||||||
if total + len(ctx) <= max_chars:
|
|
||||||
selected.append(ctx)
|
|
||||||
total += len(ctx)
|
|
||||||
else:
|
|
||||||
if score > 10 and total < max_chars - 200:
|
|
||||||
remaining = max_chars - total
|
|
||||||
lines = ctx.split('\n')
|
|
||||||
rel_lines: List[str] = []
|
|
||||||
cur = 0
|
|
||||||
for line in lines:
|
|
||||||
l = line.lower()
|
|
||||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
|
||||||
rel_lines.append(line)
|
|
||||||
cur += len(line)
|
|
||||||
if rel_lines:
|
|
||||||
truncated = '\n'.join(rel_lines)
|
|
||||||
if len(truncated) > 50:
|
|
||||||
selected.append(truncated + "\n[相关内容截断...]")
|
|
||||||
total += len(truncated)
|
|
||||||
break
|
|
||||||
return "\n\n".join(selected)
|
|
||||||
|
|
||||||
|
|
||||||
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
|
|
||||||
"""Compose a text context from `dialog` list in msc_self_instruct item."""
|
|
||||||
parts: List[str] = []
|
|
||||||
for turn in dialog_obj.get("dialog", []):
|
|
||||||
speaker = turn.get("speaker", "")
|
|
||||||
text = turn.get("text", "")
|
|
||||||
if text:
|
|
||||||
parts.append(f"{speaker}: {text}")
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
||||||
"""Combine dialogues from embedding and keyword searches (embedding first)."""
|
|
||||||
if results is None:
|
|
||||||
return []
|
|
||||||
emb = []
|
|
||||||
kw = []
|
|
||||||
if isinstance(results.get("embedding_search"), dict):
|
|
||||||
emb = results.get("embedding_search", {}).get("dialogues", []) or []
|
|
||||||
elif isinstance(results.get("dialogues"), list):
|
|
||||||
emb = results.get("dialogues", []) or []
|
|
||||||
if isinstance(results.get("keyword_search"), dict):
|
|
||||||
kw = results.get("keyword_search", {}).get("dialogues", []) or []
|
|
||||||
seen = set()
|
|
||||||
merged: List[Dict[str, Any]] = []
|
|
||||||
for d in emb:
|
|
||||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
|
||||||
if k not in seen:
|
|
||||||
merged.append(d)
|
|
||||||
seen.add(k)
|
|
||||||
for d in kw:
|
|
||||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
|
||||||
if k not in seen:
|
|
||||||
merged.append(d)
|
|
||||||
seen.add(k)
|
|
||||||
return merged
|
|
||||||
|
|
||||||
|
|
||||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
|
||||||
# Load data
|
|
||||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
|
||||||
if not os.path.exists(data_path):
|
|
||||||
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
|
||||||
with open(data_path, "r", encoding="utf-8") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
|
|
||||||
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
|
||||||
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
|
||||||
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
|
||||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
|
||||||
|
|
||||||
# LLM client (使用异步调用)
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
|
||||||
|
|
||||||
# Evaluate each item
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
latencies_llm: List[float] = []
|
|
||||||
latencies_search: List[float] = []
|
|
||||||
contexts_used: List[str] = []
|
|
||||||
correct_flags: List[float] = []
|
|
||||||
f1s: List[float] = []
|
|
||||||
b1s: List[float] = []
|
|
||||||
jss: List[float] = []
|
|
||||||
try:
|
|
||||||
for item in items:
|
|
||||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
|
||||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
|
||||||
# 检索:对齐 locomo 的三路检索(dialogues/statements/entities)
|
|
||||||
t0 = time.time()
|
|
||||||
try:
|
|
||||||
results = await run_hybrid_search(
|
|
||||||
query_text=question,
|
|
||||||
search_type=search_type,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=search_limit,
|
|
||||||
include=["dialogues", "statements", "entities"],
|
|
||||||
output_path=None,
|
|
||||||
memory_config=memory_config,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
results = None
|
|
||||||
t1 = time.time()
|
|
||||||
latencies_search.append((t1 - t0) * 1000)
|
|
||||||
|
|
||||||
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
|
|
||||||
contexts_all: List[str] = []
|
|
||||||
if results:
|
|
||||||
if search_type == "hybrid":
|
|
||||||
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
|
|
||||||
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
|
|
||||||
emb_dialogs = emb.get("dialogues", [])
|
|
||||||
emb_statements = emb.get("statements", [])
|
|
||||||
emb_entities = emb.get("entities", [])
|
|
||||||
kw_dialogs = kw.get("dialogues", [])
|
|
||||||
kw_statements = kw.get("statements", [])
|
|
||||||
kw_entities = kw.get("entities", [])
|
|
||||||
all_dialogs = emb_dialogs + kw_dialogs
|
|
||||||
all_statements = emb_statements + kw_statements
|
|
||||||
all_entities = emb_entities + kw_entities
|
|
||||||
|
|
||||||
# 简单去重与限制
|
|
||||||
seen_texts = set()
|
|
||||||
for d in all_dialogs:
|
|
||||||
text = str(d.get("content", "")).strip()
|
|
||||||
if text and text not in seen_texts:
|
|
||||||
contexts_all.append(text)
|
|
||||||
seen_texts.add(text)
|
|
||||||
if len(contexts_all) >= search_limit:
|
|
||||||
break
|
|
||||||
for s in all_statements:
|
|
||||||
text = str(s.get("statement", "")).strip()
|
|
||||||
if text and text not in seen_texts:
|
|
||||||
contexts_all.append(text)
|
|
||||||
seen_texts.add(text)
|
|
||||||
if len(contexts_all) >= search_limit:
|
|
||||||
break
|
|
||||||
# 实体摘要(最多3个)
|
|
||||||
names = []
|
|
||||||
merged_entities = all_entities[:]
|
|
||||||
for e in merged_entities:
|
|
||||||
name = str(e.get("name", "")).strip()
|
|
||||||
if name and name not in names:
|
|
||||||
names.append(name)
|
|
||||||
if len(names) >= 3:
|
|
||||||
break
|
|
||||||
if names:
|
|
||||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
|
||||||
else:
|
|
||||||
dialogs = results.get("dialogues", [])
|
|
||||||
statements = results.get("statements", [])
|
|
||||||
entities = results.get("entities", [])
|
|
||||||
for d in dialogs:
|
|
||||||
text = str(d.get("content", "")).strip()
|
|
||||||
if text:
|
|
||||||
contexts_all.append(text)
|
|
||||||
for s in statements:
|
|
||||||
text = str(s.get("statement", "")).strip()
|
|
||||||
if text:
|
|
||||||
contexts_all.append(text)
|
|
||||||
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
|
|
||||||
if names:
|
|
||||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
|
||||||
|
|
||||||
# 智能选择并截断到预算
|
|
||||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
|
||||||
if not context_text:
|
|
||||||
context_text = "No relevant context found."
|
|
||||||
contexts_used.append(context_text[:200])
|
|
||||||
|
|
||||||
# Call LLM (使用异步调用)
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
|
|
||||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
|
||||||
]
|
|
||||||
t2 = time.time()
|
|
||||||
resp = await llm_client.chat(messages=messages)
|
|
||||||
t3 = time.time()
|
|
||||||
latencies_llm.append((t3 - t2) * 1000)
|
|
||||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
|
||||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
|
||||||
correct_flags.append(exact_match(pred, reference))
|
|
||||||
from app.core.memory.evaluation.common.metrics import (
|
|
||||||
bleu1,
|
|
||||||
f1_score,
|
|
||||||
jaccard,
|
|
||||||
)
|
|
||||||
f1s.append(f1_score(str(pred), str(reference)))
|
|
||||||
b1s.append(bleu1(str(pred), str(reference)))
|
|
||||||
jss.append(jaccard(str(pred), str(reference)))
|
|
||||||
|
|
||||||
# Aggregate metrics
|
|
||||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
|
||||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
|
||||||
result = {
|
|
||||||
"dataset": "memsciqa",
|
|
||||||
"items": len(items),
|
|
||||||
"metrics": {
|
|
||||||
"accuracy": acc,
|
|
||||||
# Placeholders for extensibility
|
|
||||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
|
||||||
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
|
||||||
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
|
||||||
},
|
|
||||||
"latency": {
|
|
||||||
"search": latency_stats(latencies_search),
|
|
||||||
"llm": latency_stats(latencies_llm),
|
|
||||||
},
|
|
||||||
"avg_context_tokens": ctx_avg_tokens,
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
load_dotenv()
|
|
||||||
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
|
||||||
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
|
||||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
|
||||||
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
|
||||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
|
||||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
|
||||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
|
|
||||||
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
result = asyncio.run(
|
|
||||||
run_memsciqa_eval(
|
|
||||||
sample_size=args.sample_size,
|
|
||||||
group_id=args.group_id,
|
|
||||||
search_limit=args.search_limit,
|
|
||||||
context_char_budget=args.context_char_budget,
|
|
||||||
llm_temperature=args.llm_temperature,
|
|
||||||
llm_max_tokens=args.llm_max_tokens,
|
|
||||||
search_type=args.search_type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,576 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
try:
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
except Exception:
|
|
||||||
def load_dotenv():
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 路径与模块导入保持与现有评估脚本一致
|
|
||||||
import sys
|
|
||||||
|
|
||||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
|
||||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
|
||||||
for _p in (_SRC_DIR, _PROJECT_ROOT):
|
|
||||||
if _p not in sys.path:
|
|
||||||
sys.path.insert(0, _p)
|
|
||||||
|
|
||||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
|
||||||
from app.core.memory.evaluation.common.metrics import (
|
|
||||||
avg_context_tokens,
|
|
||||||
exact_match,
|
|
||||||
latency_stats,
|
|
||||||
)
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|
||||||
from app.core.memory.utils.config.definitions import (
|
|
||||||
PROJECT_ROOT,
|
|
||||||
SELECTED_EMBEDDING_ID,
|
|
||||||
SELECTED_GROUP_ID,
|
|
||||||
SELECTED_LLM_ID,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
|
||||||
except Exception:
|
|
||||||
# 兜底:简单实现(必要时)
|
|
||||||
def f1_score(pred: str, ref: str) -> float:
|
|
||||||
ps = pred.lower().split()
|
|
||||||
rs = ref.lower().split()
|
|
||||||
if not ps or not rs:
|
|
||||||
return 0.0
|
|
||||||
tp = len(set(ps) & set(rs))
|
|
||||||
if tp == 0:
|
|
||||||
return 0.0
|
|
||||||
precision = tp / len(ps)
|
|
||||||
recall = tp / len(rs)
|
|
||||||
if precision + recall == 0:
|
|
||||||
return 0.0
|
|
||||||
return 2 * precision * recall / (precision + recall)
|
|
||||||
|
|
||||||
def bleu1(pred: str, ref: str) -> float:
|
|
||||||
ps = pred.lower().split()
|
|
||||||
rs = ref.lower().split()
|
|
||||||
if not ps or not rs:
|
|
||||||
return 0.0
|
|
||||||
overlap = len([w for w in ps if w in rs])
|
|
||||||
return overlap / max(len(ps), 1)
|
|
||||||
|
|
||||||
def jaccard(pred: str, ref: str) -> float:
|
|
||||||
ps = set(pred.lower().split())
|
|
||||||
rs = set(ref.lower().split())
|
|
||||||
union = len(ps | rs)
|
|
||||||
if union == 0:
|
|
||||||
return 0.0
|
|
||||||
return len(ps & rs) / union
|
|
||||||
|
|
||||||
|
|
||||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
|
||||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
|
|
||||||
|
|
||||||
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
|
|
||||||
"""
|
|
||||||
if not contexts:
|
|
||||||
return ""
|
|
||||||
question_lower = (question or "").lower()
|
|
||||||
stop_words = {
|
|
||||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
|
||||||
'the','a','an','and','or','but'
|
|
||||||
}
|
|
||||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
|
||||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
|
||||||
|
|
||||||
scored = []
|
|
||||||
for i, ctx in enumerate(contexts):
|
|
||||||
ctx_lower = (ctx or "").lower()
|
|
||||||
score = 0
|
|
||||||
matches = 0
|
|
||||||
for w in question_words:
|
|
||||||
if w in ctx_lower:
|
|
||||||
matches += 1
|
|
||||||
score += ctx_lower.count(w) * 2
|
|
||||||
length = len(ctx)
|
|
||||||
if 100 < length < 2000:
|
|
||||||
score += 5
|
|
||||||
elif length >= 2000:
|
|
||||||
score += 2
|
|
||||||
if i < 3:
|
|
||||||
score += 3
|
|
||||||
scored.append((score, ctx, matches))
|
|
||||||
|
|
||||||
scored.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
|
|
||||||
selected: List[str] = []
|
|
||||||
total = 0
|
|
||||||
for score, ctx, _ in scored:
|
|
||||||
if total + len(ctx) <= max_chars:
|
|
||||||
selected.append(ctx)
|
|
||||||
total += len(ctx)
|
|
||||||
else:
|
|
||||||
if score > 10 and total < max_chars - 200:
|
|
||||||
remaining = max_chars - total
|
|
||||||
lines = ctx.split('\n')
|
|
||||||
rel_lines: List[str] = []
|
|
||||||
cur = 0
|
|
||||||
for line in lines:
|
|
||||||
l = line.lower()
|
|
||||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
|
||||||
rel_lines.append(line)
|
|
||||||
cur += len(line)
|
|
||||||
if rel_lines:
|
|
||||||
truncated = '\n'.join(rel_lines)
|
|
||||||
if len(truncated) > 50:
|
|
||||||
selected.append(truncated + "\n[相关内容截断...]")
|
|
||||||
total += len(truncated)
|
|
||||||
break
|
|
||||||
return "\n\n".join(selected)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
|
|
||||||
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3)。"""
|
|
||||||
ql = (question or "").lower()
|
|
||||||
stop_words = {
|
|
||||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
|
||||||
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
|
|
||||||
}
|
|
||||||
words = re.findall(r"\b[\w-]+\b", ql)
|
|
||||||
kws = [w for w in words if w not in stop_words and len(w) >= 3]
|
|
||||||
# 去重保序
|
|
||||||
seen = set()
|
|
||||||
uniq = []
|
|
||||||
for w in kws:
|
|
||||||
if w not in seen:
|
|
||||||
uniq.append(w)
|
|
||||||
seen.add(w)
|
|
||||||
if len(uniq) >= max_keywords:
|
|
||||||
break
|
|
||||||
return uniq
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
|
|
||||||
"""对上下文进行简单相关性打分,仅用于控制台可视化。
|
|
||||||
|
|
||||||
评分: score = match_count*200 + min(len(text), 100000)/100
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
for ctx in contexts:
|
|
||||||
tl = (ctx or "").lower()
|
|
||||||
match_count = sum(1 for k in keywords if k in tl)
|
|
||||||
length = len(ctx)
|
|
||||||
score = match_count * 200 + min(length, 100000) / 100.0
|
|
||||||
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
|
|
||||||
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
|
|
||||||
return results[:max(top_n, 0)]
|
|
||||||
|
|
||||||
|
|
||||||
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
|
|
||||||
if not os.path.exists(data_path):
|
|
||||||
raise FileNotFoundError(f"未找到数据集: {data_path}")
|
|
||||||
items: List[Dict[str, Any]] = []
|
|
||||||
with open(data_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
items.append(json.loads(line))
|
|
||||||
except Exception:
|
|
||||||
# 跳过坏行但不中断
|
|
||||||
continue
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
async def run_memsciqa_test(
|
|
||||||
sample_size: int = 3,
|
|
||||||
group_id: str | None = None,
|
|
||||||
search_limit: int = 8,
|
|
||||||
context_char_budget: int = 4000,
|
|
||||||
llm_temperature: float = 0.0,
|
|
||||||
llm_max_tokens: int = 64,
|
|
||||||
search_type: str = "embedding",
|
|
||||||
data_path: str | None = None,
|
|
||||||
start_index: int = 0,
|
|
||||||
verbose: bool = True,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
|
|
||||||
|
|
||||||
- 支持从指定索引开始与评估全部样本(sample_size<=0)
|
|
||||||
- 支持在摄入前重置组(清空图)与跳过摄入
|
|
||||||
- 支持 keyword / embedding / hybrid 三种检索
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 默认使用指定的 memsci 组 ID
|
|
||||||
group_id = group_id or "group_memsci"
|
|
||||||
|
|
||||||
# 数据路径解析(项目根与当前工作目录兜底)
|
|
||||||
if not data_path:
|
|
||||||
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
|
||||||
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
|
||||||
if os.path.exists(proj_path):
|
|
||||||
data_path = proj_path
|
|
||||||
elif os.path.exists(cwd_path):
|
|
||||||
data_path = cwd_path
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。")
|
|
||||||
|
|
||||||
# 加载数据
|
|
||||||
all_items = load_dataset_memsciqa(data_path)
|
|
||||||
if sample_size is None or sample_size <= 0:
|
|
||||||
items = all_items[start_index:]
|
|
||||||
else:
|
|
||||||
items = all_items[start_index:start_index + sample_size]
|
|
||||||
|
|
||||||
# 初始化 LLM(纯测试:不进行摄入)
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
|
||||||
|
|
||||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
embedder = None
|
|
||||||
if search_type in ("embedding", "hybrid"):
|
|
||||||
with get_db_context() as db:
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
|
||||||
embedder = OpenAIEmbedderClient(
|
|
||||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 评估循环
|
|
||||||
latencies_llm: List[float] = []
|
|
||||||
latencies_search: List[float] = []
|
|
||||||
# 存储完整上下文文本用于统计
|
|
||||||
contexts_used: List[str] = []
|
|
||||||
per_query_context_chars: List[int] = []
|
|
||||||
per_query_context_counts: List[int] = []
|
|
||||||
correct_flags: List[float] = []
|
|
||||||
f1s: List[float] = []
|
|
||||||
b1s: List[float] = []
|
|
||||||
jss: List[float] = []
|
|
||||||
samples: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
total_items = len(items)
|
|
||||||
for idx, item in enumerate(items):
|
|
||||||
if verbose:
|
|
||||||
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
|
|
||||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
|
||||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
|
||||||
|
|
||||||
# 三路检索:chunks/statements/entities/summaries(对齐 qwen_search_eval.py)
|
|
||||||
t0 = time.time()
|
|
||||||
results = None
|
|
||||||
try:
|
|
||||||
if search_type in ("embedding", "hybrid"):
|
|
||||||
# 使用嵌入检索(与 qwen_search_eval 对齐)
|
|
||||||
results = await search_graph_by_embedding(
|
|
||||||
connector=connector,
|
|
||||||
embedder_client=embedder,
|
|
||||||
query_text=question,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=search_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
|
||||||
)
|
|
||||||
elif search_type == "keyword":
|
|
||||||
# 关键词检索(直接调用 graph_search)
|
|
||||||
results = await search_graph(
|
|
||||||
connector=connector,
|
|
||||||
q=question,
|
|
||||||
group_id=group_id,
|
|
||||||
limit=search_limit,
|
|
||||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
results = None
|
|
||||||
t1 = time.time()
|
|
||||||
search_ms = (t1 - t0) * 1000
|
|
||||||
latencies_search.append(search_ms)
|
|
||||||
|
|
||||||
# 构建上下文:包含 chunks、陈述、摘要和实体(对齐 qwen_search_eval.py)
|
|
||||||
contexts_all: List[str] = []
|
|
||||||
retrieved_counts: Dict[str, int] = {}
|
|
||||||
if results:
|
|
||||||
chunks = results.get("chunks", [])
|
|
||||||
statements = results.get("statements", [])
|
|
||||||
entities = results.get("entities", [])
|
|
||||||
summaries = results.get("summaries", [])
|
|
||||||
retrieved_counts = {
|
|
||||||
"chunks": len(chunks),
|
|
||||||
"statements": len(statements),
|
|
||||||
"entities": len(entities),
|
|
||||||
"summaries": len(summaries),
|
|
||||||
}
|
|
||||||
# 优先使用 chunks
|
|
||||||
for c in chunks:
|
|
||||||
text = str(c.get("content", "")).strip()
|
|
||||||
if text:
|
|
||||||
contexts_all.append(text)
|
|
||||||
# 然后是 statements
|
|
||||||
for s in statements:
|
|
||||||
text = str(s.get("statement", "")).strip()
|
|
||||||
if text:
|
|
||||||
contexts_all.append(text)
|
|
||||||
# 然后是 summaries
|
|
||||||
for sm in summaries:
|
|
||||||
text = str(sm.get("summary", "")).strip()
|
|
||||||
if text:
|
|
||||||
contexts_all.append(text)
|
|
||||||
# 实体摘要:最多加入前3个高分实体(对齐 qwen_search_eval.py)
|
|
||||||
scored = [e for e in entities if e.get("score") is not None]
|
|
||||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
|
||||||
if top_entities:
|
|
||||||
summary_lines = []
|
|
||||||
for e in top_entities:
|
|
||||||
name = str(e.get("name", "")).strip()
|
|
||||||
etype = str(e.get("entity_type", "")).strip()
|
|
||||||
score = e.get("score")
|
|
||||||
if name:
|
|
||||||
meta = []
|
|
||||||
if etype:
|
|
||||||
meta.append(f"type={etype}")
|
|
||||||
if isinstance(score, (int, float)):
|
|
||||||
meta.append(f"score={score:.3f}")
|
|
||||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
|
||||||
if summary_lines:
|
|
||||||
contexts_all.append("\n".join(summary_lines))
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
if retrieved_counts:
|
|
||||||
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
|
|
||||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
|
||||||
q_keywords = extract_question_keywords(question, max_keywords=8)
|
|
||||||
if q_keywords:
|
|
||||||
print(f"🔍 问题关键词: {set(q_keywords)}")
|
|
||||||
if contexts_all:
|
|
||||||
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
|
|
||||||
if analysis:
|
|
||||||
print("📊 上下文相关性分析:")
|
|
||||||
for a in analysis:
|
|
||||||
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
|
|
||||||
# 打印检索到的上下文预览,便于定位为何为 Unknown
|
|
||||||
print("🔎 上下文预览(最多前10条,每条截断展示):")
|
|
||||||
for i, ctx in enumerate(contexts_all[:10]):
|
|
||||||
preview = str(ctx).replace("\n", " ")
|
|
||||||
if len(preview) > 300:
|
|
||||||
preview = preview[:300] + "..."
|
|
||||||
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
|
|
||||||
# 标注参考答案是否出现在任一上下文中
|
|
||||||
ref_lower = (str(reference) or "").lower()
|
|
||||||
if ref_lower:
|
|
||||||
hits = []
|
|
||||||
for i, ctx in enumerate(contexts_all):
|
|
||||||
if ref_lower in str(ctx).lower():
|
|
||||||
hits.append(i+1)
|
|
||||||
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
|
|
||||||
|
|
||||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
|
||||||
if not context_text:
|
|
||||||
context_text = "No relevant context found."
|
|
||||||
contexts_used.append(context_text)
|
|
||||||
per_query_context_chars.append(len(context_text))
|
|
||||||
per_query_context_counts.append(len(contexts_all))
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
|
|
||||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
|
|
||||||
# 展示拼接后的上下文片段,便于核查是否包含答案
|
|
||||||
concat_preview = context_text.replace("\n", " ")
|
|
||||||
if len(concat_preview) > 600:
|
|
||||||
concat_preview = concat_preview[:600] + "..."
|
|
||||||
print(f"🧵 拼接上下文预览: {concat_preview}")
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": (
|
|
||||||
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
|
|
||||||
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
|
|
||||||
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
|
|
||||||
"3) Keep your answer brief and to the point;\n"
|
|
||||||
"4) Do not add explanations or additional text beyond the answer."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
|
||||||
]
|
|
||||||
|
|
||||||
t2 = time.time()
|
|
||||||
try:
|
|
||||||
# 使用异步调用
|
|
||||||
resp = await llm.chat(messages=messages)
|
|
||||||
# 更健壮的响应解析,处理不同的LLM响应格式
|
|
||||||
if hasattr(resp, 'content'):
|
|
||||||
pred = resp.content.strip()
|
|
||||||
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
|
|
||||||
pred = resp["choices"][0]["message"]["content"].strip()
|
|
||||||
elif isinstance(resp, dict) and "content" in resp:
|
|
||||||
pred = resp["content"].strip()
|
|
||||||
elif isinstance(resp, str):
|
|
||||||
pred = resp.strip()
|
|
||||||
else:
|
|
||||||
pred = "Unknown"
|
|
||||||
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
|
|
||||||
|
|
||||||
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
|
|
||||||
if pred.lower() in ["unknown", ""]:
|
|
||||||
# 如果参考答案在上下文中存在,但LLM返回Unknown,可能是提示词问题
|
|
||||||
ref_lower = (str(reference) or "").lower()
|
|
||||||
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
|
|
||||||
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词")
|
|
||||||
except Exception as e:
|
|
||||||
# 更详细的错误处理
|
|
||||||
pred = "Unknown"
|
|
||||||
print(f"⚠️ LLM调用异常: {e}")
|
|
||||||
t3 = time.time()
|
|
||||||
llm_ms = (t3 - t2) * 1000
|
|
||||||
latencies_llm.append(llm_ms)
|
|
||||||
|
|
||||||
exact = exact_match(pred, reference)
|
|
||||||
correct_flags.append(exact)
|
|
||||||
f1_val = f1_score(str(pred), str(reference))
|
|
||||||
b1_val = bleu1(str(pred), str(reference))
|
|
||||||
j_val = jaccard(str(pred), str(reference))
|
|
||||||
f1s.append(f1_val)
|
|
||||||
b1s.append(b1_val)
|
|
||||||
jss.append(j_val)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"🤖 LLM 回答: {pred}")
|
|
||||||
print(f"✅ 正确答案: {reference}")
|
|
||||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
|
|
||||||
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
|
|
||||||
|
|
||||||
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
|
|
||||||
samples.append({
|
|
||||||
"question": str(question),
|
|
||||||
"answer": str(reference),
|
|
||||||
"prediction": str(pred),
|
|
||||||
"metrics": {
|
|
||||||
"f1": f1_val,
|
|
||||||
"b1": b1_val,
|
|
||||||
"j": j_val
|
|
||||||
},
|
|
||||||
"retrieval": {
|
|
||||||
"retrieved_documents": len(contexts_all),
|
|
||||||
"context_length": len(context_text),
|
|
||||||
"search_limit": search_limit,
|
|
||||||
"max_chars": context_char_budget
|
|
||||||
},
|
|
||||||
"timing": {
|
|
||||||
"search_ms": search_ms,
|
|
||||||
"llm_ms": llm_ms
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# 计算总体指标与聚合
|
|
||||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
|
||||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
|
||||||
result = {
|
|
||||||
"dataset": "memsciqa",
|
|
||||||
"items": len(items),
|
|
||||||
"metrics": {
|
|
||||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
|
||||||
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
|
||||||
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
|
||||||
},
|
|
||||||
"context": {
|
|
||||||
"avg_tokens": ctx_avg_tokens,
|
|
||||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
|
||||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
|
||||||
"avg_memory_tokens": 0.0
|
|
||||||
},
|
|
||||||
"latency": {
|
|
||||||
"search": latency_stats(latencies_search),
|
|
||||||
"llm": latency_stats(latencies_llm),
|
|
||||||
},
|
|
||||||
"samples": samples,
|
|
||||||
"params": {
|
|
||||||
"group_id": group_id,
|
|
||||||
"search_limit": search_limit,
|
|
||||||
"context_char_budget": context_char_budget,
|
|
||||||
"llm_temperature": llm_temperature,
|
|
||||||
"llm_max_tokens": llm_max_tokens,
|
|
||||||
"search_type": search_type,
|
|
||||||
"start_index": start_index,
|
|
||||||
"llm_id": SELECTED_LLM_ID,
|
|
||||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
|
|
||||||
},
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
await connector.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
load_dotenv()
|
|
||||||
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
|
|
||||||
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
|
|
||||||
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
|
|
||||||
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
|
|
||||||
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)")
|
|
||||||
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
|
|
||||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
|
||||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
|
||||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
|
|
||||||
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)")
|
|
||||||
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl)")
|
|
||||||
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)")
|
|
||||||
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
|
|
||||||
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
sample_size = 0 if args.all else args.sample_size
|
|
||||||
|
|
||||||
verbose_flag = False if args.quiet else args.verbose
|
|
||||||
result = asyncio.run(
|
|
||||||
run_memsciqa_test(
|
|
||||||
sample_size=sample_size,
|
|
||||||
group_id=args.group_id,
|
|
||||||
search_limit=args.search_limit,
|
|
||||||
context_char_budget=args.context_char_budget,
|
|
||||||
llm_temperature=args.llm_temperature,
|
|
||||||
llm_max_tokens=args.llm_max_tokens,
|
|
||||||
search_type=args.search_type,
|
|
||||||
data_path=args.data_path,
|
|
||||||
start_index=args.start_index,
|
|
||||||
verbose=verbose_flag,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
||||||
|
|
||||||
# 结果保存
|
|
||||||
out_path = args.output
|
|
||||||
if not out_path:
|
|
||||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
dataset_results_dir = os.path.join(eval_dir, "results")
|
|
||||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
|
|
||||||
try:
|
|
||||||
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
|
||||||
with open(out_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
|
||||||
print(f"\n💾 结果已保存: {out_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ 结果保存失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,150 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
# Add src directory to Python path for proper imports when running from evaluation directory
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
except Exception:
|
|
||||||
def load_dotenv():
|
|
||||||
return None
|
|
||||||
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
|
||||||
|
|
||||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
|
||||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
|
||||||
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
|
|
||||||
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
dataset: str,
|
|
||||||
sample_size: int,
|
|
||||||
reset_group: bool,
|
|
||||||
group_id: str | None,
|
|
||||||
judge_model: str | None = None,
|
|
||||||
search_limit: int | None = None,
|
|
||||||
context_char_budget: int | None = None,
|
|
||||||
llm_temperature: float | None = None,
|
|
||||||
llm_max_tokens: int | None = None,
|
|
||||||
search_type: str | None = None,
|
|
||||||
start_index: int | None = None,
|
|
||||||
max_contexts_per_item: int | None = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
|
||||||
group_id = group_id or SELECTED_GROUP_ID
|
|
||||||
|
|
||||||
if reset_group:
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
try:
|
|
||||||
await connector.delete_group(group_id)
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
if dataset == "locomo":
|
|
||||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
|
||||||
if search_limit is not None:
|
|
||||||
kwargs["search_limit"] = search_limit
|
|
||||||
if context_char_budget is not None:
|
|
||||||
kwargs["context_char_budget"] = context_char_budget
|
|
||||||
if llm_temperature is not None:
|
|
||||||
kwargs["llm_temperature"] = llm_temperature
|
|
||||||
if llm_max_tokens is not None:
|
|
||||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
|
||||||
if search_type is not None:
|
|
||||||
kwargs["search_type"] = search_type
|
|
||||||
return await run_locomo_eval(**kwargs)
|
|
||||||
|
|
||||||
if dataset == "memsciqa":
|
|
||||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
|
||||||
if search_limit is not None:
|
|
||||||
kwargs["search_limit"] = search_limit
|
|
||||||
if context_char_budget is not None:
|
|
||||||
kwargs["context_char_budget"] = context_char_budget
|
|
||||||
if llm_temperature is not None:
|
|
||||||
kwargs["llm_temperature"] = llm_temperature
|
|
||||||
if llm_max_tokens is not None:
|
|
||||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
|
||||||
if search_type is not None:
|
|
||||||
kwargs["search_type"] = search_type
|
|
||||||
return await run_memsciqa_eval(**kwargs)
|
|
||||||
|
|
||||||
if dataset == "longmemeval":
|
|
||||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
|
||||||
if search_limit is not None:
|
|
||||||
kwargs["search_limit"] = search_limit
|
|
||||||
if context_char_budget is not None:
|
|
||||||
kwargs["context_char_budget"] = context_char_budget
|
|
||||||
if llm_temperature is not None:
|
|
||||||
kwargs["llm_temperature"] = llm_temperature
|
|
||||||
if llm_max_tokens is not None:
|
|
||||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
|
||||||
if search_type is not None:
|
|
||||||
kwargs["search_type"] = search_type
|
|
||||||
if start_index is not None:
|
|
||||||
kwargs["start_index"] = start_index
|
|
||||||
if max_contexts_per_item is not None:
|
|
||||||
kwargs["max_contexts_per_item"] = max_contexts_per_item
|
|
||||||
return await run_longmemeval_test(**kwargs)
|
|
||||||
raise ValueError(f"未知数据集: {dataset}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
load_dotenv()
|
|
||||||
parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
|
|
||||||
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
|
|
||||||
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
|
|
||||||
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
|
|
||||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
|
||||||
parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
|
|
||||||
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
|
|
||||||
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
|
|
||||||
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
|
|
||||||
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)")
|
|
||||||
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
|
|
||||||
# 仅透传到 longmemeval;其他数据集忽略
|
|
||||||
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)")
|
|
||||||
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)")
|
|
||||||
parser.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation/<dataset>/results 目录")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
result = asyncio.run(run(
|
|
||||||
args.dataset,
|
|
||||||
args.sample_size,
|
|
||||||
args.reset_group,
|
|
||||||
args.group_id,
|
|
||||||
args.judge_model,
|
|
||||||
args.search_limit,
|
|
||||||
args.context_char_budget,
|
|
||||||
args.llm_temperature,
|
|
||||||
args.llm_max_tokens,
|
|
||||||
args.search_type,
|
|
||||||
args.start_index,
|
|
||||||
args.max_contexts_per_item,
|
|
||||||
))
|
|
||||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
||||||
|
|
||||||
# 结果输出逻辑保持不变
|
|
||||||
if args.output:
|
|
||||||
out_path = args.output
|
|
||||||
else:
|
|
||||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
|
|
||||||
out_filename = f"{args.dataset}_{args.sample_size}.json"
|
|
||||||
out_path = os.path.join(dataset_results_dir, out_filename)
|
|
||||||
|
|
||||||
out_dir = os.path.dirname(out_path)
|
|
||||||
if out_dir and not os.path.exists(out_dir):
|
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
|
||||||
with open(out_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
|
||||||
print(f"\n结果已保存到: {out_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -187,11 +187,11 @@ class ChunkerClient:
|
|||||||
async def generate_chunks(self, dialogue: DialogData):
|
async def generate_chunks(self, dialogue: DialogData):
|
||||||
"""
|
"""
|
||||||
Generate chunks following 1 Message = 1 Chunk strategy.
|
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||||
|
|
||||||
Each message creates one chunk, directly inheriting role information.
|
Each message creates one chunk, directly inheriting role information.
|
||||||
If a message is too long, it will be split into multiple sub-chunks,
|
If a message is too long, it will be split into multiple sub-chunks,
|
||||||
each maintaining the same speaker.
|
each maintaining the same speaker.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If dialogue has no messages or chunking fails
|
ValueError: If dialogue has no messages or chunking fails
|
||||||
"""
|
"""
|
||||||
@@ -201,9 +201,9 @@ class ChunkerClient:
|
|||||||
f"Dialogue {dialogue.ref_id} has no messages. "
|
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||||
f"Cannot generate chunks from empty dialogue."
|
f"Cannot generate chunks from empty dialogue."
|
||||||
)
|
)
|
||||||
|
|
||||||
dialogue.chunks = []
|
dialogue.chunks = []
|
||||||
|
|
||||||
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||||
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||||
# Validate message has required attributes
|
# Validate message has required attributes
|
||||||
@@ -212,13 +212,13 @@ class ChunkerClient:
|
|||||||
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||||
f"missing 'role' or 'msg' attribute"
|
f"missing 'role' or 'msg' attribute"
|
||||||
)
|
)
|
||||||
|
|
||||||
msg_content = msg.msg.strip()
|
msg_content = msg.msg.strip()
|
||||||
|
|
||||||
# Skip empty messages
|
# Skip empty messages
|
||||||
if not msg_content:
|
if not msg_content:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果消息太长,可以进一步分块
|
# 如果消息太长,可以进一步分块
|
||||||
if len(msg_content) > self.chunk_size:
|
if len(msg_content) > self.chunk_size:
|
||||||
# 对单个消息的内容进行分块
|
# 对单个消息的内容进行分块
|
||||||
@@ -228,14 +228,14 @@ class ChunkerClient:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, sub_chunk in enumerate(sub_chunks):
|
for idx, sub_chunk in enumerate(sub_chunks):
|
||||||
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||||
sub_chunk_text = sub_chunk_text.strip()
|
sub_chunk_text = sub_chunk_text.strip()
|
||||||
|
|
||||||
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = Chunk(
|
chunk = Chunk(
|
||||||
content=f"{msg.role}: {sub_chunk_text}",
|
content=f"{msg.role}: {sub_chunk_text}",
|
||||||
speaker=msg.role, # 直接继承角色
|
speaker=msg.role, # 直接继承角色
|
||||||
@@ -260,7 +260,7 @@ class ChunkerClient:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
|
|
||||||
# Validate we generated at least one chunk
|
# Validate we generated at least one chunk
|
||||||
if not dialogue.chunks:
|
if not dialogue.chunks:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -268,7 +268,7 @@ class ChunkerClient:
|
|||||||
f"All messages were either empty or too short. "
|
f"All messages were either empty or too short. "
|
||||||
f"Messages count: {len(dialogue.context.msgs)}"
|
f"Messages count: {len(dialogue.context.msgs)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return dialogue
|
return dialogue
|
||||||
|
|
||||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
|
|||||||
"""Parameters for temporal search queries in the knowledge graph.
|
"""Parameters for temporal search queries in the knowledge graph.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
group_id: Group ID to filter search results (default: 'test')
|
end_user_id: Group ID to filter search results (default: 'test')
|
||||||
apply_id: Application ID to filter search results
|
apply_id: Application ID to filter search results
|
||||||
user_id: User ID to filter search results
|
user_id: User ID to filter search results
|
||||||
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
||||||
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
|
|||||||
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
||||||
limit: Maximum number of results to return (default: 3)
|
limit: Maximum number of results to return (default: 3)
|
||||||
"""
|
"""
|
||||||
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||||
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
||||||
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
||||||
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
||||||
|
|||||||
@@ -103,9 +103,7 @@ class Edge(BaseModel):
|
|||||||
id: Unique identifier for the edge
|
id: Unique identifier for the edge
|
||||||
source: ID of the source node
|
source: ID of the source node
|
||||||
target: ID of the target node
|
target: ID of the target node
|
||||||
group_id: Group ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
user_id: User ID for user-specific data
|
|
||||||
apply_id: Application ID for application-specific data
|
|
||||||
run_id: Unique identifier for the pipeline run that created this edge
|
run_id: Unique identifier for the pipeline run that created this edge
|
||||||
created_at: Timestamp when the edge was created (system perspective)
|
created_at: Timestamp when the edge was created (system perspective)
|
||||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||||
@@ -113,9 +111,7 @@ class Edge(BaseModel):
|
|||||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||||
source: str = Field(..., description="The ID of the source node.")
|
source: str = Field(..., description="The ID of the source node.")
|
||||||
target: str = Field(..., description="The ID of the target node.")
|
target: str = Field(..., description="The ID of the target node.")
|
||||||
group_id: str = Field(..., description="The group ID of the edge.")
|
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||||
user_id: str = Field(..., description="The user ID of the edge.")
|
|
||||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||||
@@ -185,18 +181,14 @@ class Node(BaseModel):
|
|||||||
Attributes:
|
Attributes:
|
||||||
id: Unique identifier for the node
|
id: Unique identifier for the node
|
||||||
name: Name of the node
|
name: Name of the node
|
||||||
group_id: Group ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
user_id: User ID for user-specific data
|
|
||||||
apply_id: Application ID for application-specific data
|
|
||||||
run_id: Unique identifier for the pipeline run that created this node
|
run_id: Unique identifier for the pipeline run that created this node
|
||||||
created_at: Timestamp when the node was created (system perspective)
|
created_at: Timestamp when the node was created (system perspective)
|
||||||
expired_at: Optional timestamp when the node expires (system perspective)
|
expired_at: Optional timestamp when the node expires (system perspective)
|
||||||
"""
|
"""
|
||||||
id: str = Field(..., description="The unique identifier for the node.")
|
id: str = Field(..., description="The unique identifier for the node.")
|
||||||
name: str = Field(..., description="The name of the node.")
|
name: str = Field(..., description="The name of the node.")
|
||||||
group_id: str = Field(..., description="The group ID of the node.")
|
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||||
user_id: str = Field(..., description="The user ID of the edge.")
|
|
||||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class Statement(BaseModel):
|
|||||||
Attributes:
|
Attributes:
|
||||||
id: Unique identifier for the statement
|
id: Unique identifier for the statement
|
||||||
chunk_id: ID of the parent chunk this statement belongs to
|
chunk_id: ID of the parent chunk this statement belongs to
|
||||||
group_id: Optional group ID for multi-tenancy
|
end_user_id: Optional group ID for multi-tenancy
|
||||||
statement: The actual statement text content
|
statement: The actual statement text content
|
||||||
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||||
statement_embedding: Optional embedding vector for the statement
|
statement_embedding: Optional embedding vector for the statement
|
||||||
@@ -73,7 +73,7 @@ class Statement(BaseModel):
|
|||||||
"""
|
"""
|
||||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||||
statement: str = Field(..., description="The text content of the statement.")
|
statement: str = Field(..., description="The text content of the statement.")
|
||||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||||
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
|
|||||||
context: Full conversation context
|
context: Full conversation context
|
||||||
dialog_embedding: Optional embedding vector for the entire dialog
|
dialog_embedding: Optional embedding vector for the entire dialog
|
||||||
ref_id: Reference ID linking to external dialog system
|
ref_id: Reference ID linking to external dialog system
|
||||||
group_id: Group ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
user_id: User ID for user-specific data
|
|
||||||
apply_id: Application ID for application-specific data
|
|
||||||
created_at: Timestamp when the dialog was created
|
created_at: Timestamp when the dialog was created
|
||||||
expired_at: Timestamp when the dialog expires (default: far future)
|
expired_at: Timestamp when the dialog expires (default: far future)
|
||||||
metadata: Additional metadata as key-value pairs
|
metadata: Additional metadata as key-value pairs
|
||||||
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
|
|||||||
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
||||||
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
||||||
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
||||||
group_id: str = Field(default=..., description="Group ID of dialogue data")
|
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||||
user_id: str = Field(..., description="USER ID of dialogue data")
|
|
||||||
apply_id: str = Field(..., description="APPLY ID of dialogue data")
|
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||||
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def assign_group_id_to_statements(self) -> None:
|
def assign_group_id_to_statements(self) -> None:
|
||||||
"""Assign this dialog's group_id to all statements in all chunks.
|
"""Assign this dialog's end_user_id to all statements in all chunks.
|
||||||
|
|
||||||
This method updates statements that don't have a group_id set.
|
This method updates statements that don't have a end_user_id set.
|
||||||
"""
|
"""
|
||||||
for chunk in self.chunks:
|
for chunk in self.chunks:
|
||||||
for statement in chunk.statements:
|
for statement in chunk.statements:
|
||||||
if statement.group_id is None:
|
if statement.end_user_id is None:
|
||||||
statement.group_id = self.group_id
|
statement.end_user_id = self.end_user_id
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
@@ -396,13 +397,13 @@ def rerank_with_activation(
|
|||||||
return reranked
|
return reranked
|
||||||
|
|
||||||
|
|
||||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
|
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||||
"""Log search query information using the logger.
|
"""Log search query information using the logger.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: The search query text
|
query_text: The search query text
|
||||||
search_type: Type of search (keyword, embedding, hybrid)
|
search_type: Type of search (keyword, embedding, hybrid)
|
||||||
group_id: Group identifier for filtering
|
end_user_id: Group identifier for filtering
|
||||||
limit: Maximum number of results
|
limit: Maximum number of results
|
||||||
include: List of result types to include
|
include: List of result types to include
|
||||||
log_file: Deprecated parameter, kept for backward compatibility
|
log_file: Deprecated parameter, kept for backward compatibility
|
||||||
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
|
|||||||
# Log using the standard logger
|
# Log using the standard logger
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||||
f"group_id={group_id}, limit={limit}, include={include}"
|
f"end_user_id={end_user_id}, limit={limit}, include={include}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
|
|||||||
async def run_hybrid_search(
|
async def run_hybrid_search(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
search_type: str,
|
search_type: str,
|
||||||
group_id: str | None,
|
end_user_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
include: List[str],
|
include: List[str],
|
||||||
output_path: str | None,
|
output_path: str | None,
|
||||||
@@ -715,7 +716,7 @@ async def run_hybrid_search(
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Log the search query
|
# Log the search query
|
||||||
log_search_query(query_text, search_type, group_id, limit, include)
|
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||||
|
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
results = {}
|
results = {}
|
||||||
@@ -732,7 +733,7 @@ async def run_hybrid_search(
|
|||||||
search_graph(
|
search_graph(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
q=query_text,
|
q=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include
|
include=include
|
||||||
)
|
)
|
||||||
@@ -769,7 +770,7 @@ async def run_hybrid_search(
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
embedder_client=embedder,
|
embedder_client=embedder,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include,
|
include=include,
|
||||||
)
|
)
|
||||||
@@ -916,9 +917,7 @@ async def run_hybrid_search(
|
|||||||
|
|
||||||
|
|
||||||
async def search_by_temporal(
|
async def search_by_temporal(
|
||||||
group_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
@@ -929,7 +928,7 @@ async def search_by_temporal(
|
|||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
|
|
||||||
- Matches statements created between start_date and end_date
|
- Matches statements created between start_date and end_date
|
||||||
- Optionally filters by group_id
|
- Optionally filters by end_user_id
|
||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
@@ -939,9 +938,7 @@ async def search_by_temporal(
|
|||||||
end_date = normalize_date_safe(end_date)
|
end_date = normalize_date_safe(end_date)
|
||||||
|
|
||||||
params = TemporalSearchParams.model_validate({
|
params = TemporalSearchParams.model_validate({
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"apply_id": apply_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"start_date": start_date,
|
"start_date": start_date,
|
||||||
"end_date": end_date,
|
"end_date": end_date,
|
||||||
"valid_date": valid_date,
|
"valid_date": valid_date,
|
||||||
@@ -950,9 +947,7 @@ async def search_by_temporal(
|
|||||||
})
|
})
|
||||||
statements = await search_graph_by_temporal(
|
statements = await search_graph_by_temporal(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
group_id=params.group_id,
|
end_user_id=params.end_user_id,
|
||||||
apply_id=params.apply_id,
|
|
||||||
user_id=params.user_id,
|
|
||||||
start_date=params.start_date,
|
start_date=params.start_date,
|
||||||
end_date=params.end_date,
|
end_date=params.end_date,
|
||||||
valid_date=params.valid_date,
|
valid_date=params.valid_date,
|
||||||
@@ -964,9 +959,7 @@ async def search_by_temporal(
|
|||||||
|
|
||||||
async def search_by_keyword_temporal(
|
async def search_by_keyword_temporal(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
apply_id: Optional[str] = None,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
|
|||||||
invalid_date = normalize_date_safe(invalid_date)
|
invalid_date = normalize_date_safe(invalid_date)
|
||||||
|
|
||||||
params = TemporalSearchParams.model_validate({
|
params = TemporalSearchParams.model_validate({
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"apply_id": apply_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"start_date": start_date,
|
"start_date": start_date,
|
||||||
"end_date": end_date,
|
"end_date": end_date,
|
||||||
"valid_date": valid_date,
|
"valid_date": valid_date,
|
||||||
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
|
|||||||
statements = await search_graph_by_keyword_temporal(
|
statements = await search_graph_by_keyword_temporal(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=params.group_id,
|
end_user_id=params.end_user_id,
|
||||||
apply_id=params.apply_id,
|
|
||||||
user_id=params.user_id,
|
|
||||||
start_date=params.start_date,
|
start_date=params.start_date,
|
||||||
end_date=params.end_date,
|
end_date=params.end_date,
|
||||||
valid_date=params.valid_date,
|
valid_date=params.valid_date,
|
||||||
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
|
|||||||
|
|
||||||
async def search_chunk_by_chunk_id(
|
async def search_chunk_by_chunk_id(
|
||||||
chunk_id: str,
|
chunk_id: str,
|
||||||
group_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
|
|||||||
chunks = await search_graph_by_chunk_id(
|
chunks = await search_graph_by_chunk_id(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
chunk_id=chunk_id,
|
chunk_id=chunk_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
return {"chunks": chunks}
|
return {"chunks": chunks}
|
||||||
|
|||||||
@@ -555,8 +555,8 @@ class DataPreprocessor:
|
|||||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||||
|
|
||||||
|
|
||||||
# 获取group_id,如果不存在则生成默认值
|
# 获取end_user_id,如果不存在则生成默认值
|
||||||
group_id = item.get('group_id', f'group_default_{i}')
|
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||||
user_id = item.get('user_id', f'user_default_{i}')
|
user_id = item.get('user_id', f'user_default_{i}')
|
||||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||||
|
|
||||||
@@ -574,7 +574,7 @@ class DataPreprocessor:
|
|||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=context,
|
context=context,
|
||||||
ref_id=dialog_id,
|
ref_id=dialog_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
metadata=metadata
|
metadata=metadata
|
||||||
@@ -644,7 +644,7 @@ class DataPreprocessor:
|
|||||||
|
|
||||||
context = ConversationContext(msgs=messages)
|
context = ConversationContext(msgs=messages)
|
||||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||||
group_id = item.get('group_id', f'group_default_{i}')
|
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||||
user_id = item.get('user_id', f'user_default_{i}')
|
user_id = item.get('user_id', f'user_default_{i}')
|
||||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||||
|
|
||||||
@@ -657,7 +657,7 @@ class DataPreprocessor:
|
|||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=context,
|
context=context,
|
||||||
ref_id=dialog_id,
|
ref_id=dialog_id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
apply_id=apply_id,
|
apply_id=apply_id,
|
||||||
metadata=metadata
|
metadata=metadata
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ def accurate_match(
|
|||||||
entity_nodes: List[ExtractedEntityNode]
|
entity_nodes: List[ExtractedEntityNode]
|
||||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||||
"""
|
"""
|
||||||
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||||
"""
|
"""
|
||||||
exact_merge_map: Dict[str, Dict] = {}
|
exact_merge_map: Dict[str, Dict] = {}
|
||||||
@@ -210,8 +210,8 @@ def accurate_match(
|
|||||||
for ent in entity_nodes:
|
for ent in entity_nodes:
|
||||||
name_norm = (getattr(ent, "name", "") or "").strip()
|
name_norm = (getattr(ent, "name", "") or "").strip()
|
||||||
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
||||||
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
|
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
|
||||||
# 为避免跨业务组误并,明确以 group_id 为范围边界
|
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
|
||||||
if key not in canonical_map:
|
if key not in canonical_map:
|
||||||
canonical_map[key] = ent
|
canonical_map[key] = ent
|
||||||
id_redirect[ent.id] = ent.id
|
id_redirect[ent.id] = ent.id
|
||||||
@@ -223,11 +223,11 @@ def accurate_match(
|
|||||||
id_redirect[ent.id] = canonical.id
|
id_redirect[ent.id] = canonical.id
|
||||||
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
||||||
try:
|
try:
|
||||||
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||||
if k not in exact_merge_map:
|
if k not in exact_merge_map:
|
||||||
exact_merge_map[k] = {
|
exact_merge_map[k] = {
|
||||||
"canonical_id": canonical.id,
|
"canonical_id": canonical.id,
|
||||||
"group_id": canonical.group_id,
|
"end_user_id": canonical.end_user_id,
|
||||||
"name": canonical.name,
|
"name": canonical.name,
|
||||||
"entity_type": canonical.entity_type,
|
"entity_type": canonical.entity_type,
|
||||||
"merged_ids": set(),
|
"merged_ids": set(),
|
||||||
@@ -596,7 +596,7 @@ def fuzzy_match(
|
|||||||
b = deduped_entities[j]
|
b = deduped_entities[j]
|
||||||
|
|
||||||
# 跳过不同业务组的实体
|
# 跳过不同业务组的实体
|
||||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||||
j += 1
|
j += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -671,7 +671,7 @@ def fuzzy_match(
|
|||||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||||
fuzzy_merge_records.append(
|
fuzzy_merge_records.append(
|
||||||
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
|
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
|
||||||
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
|||||||
# 记录 LLM 融合日志
|
# 记录 LLM 融合日志
|
||||||
try:
|
try:
|
||||||
llm_records.append(
|
llm_records.append(
|
||||||
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||||
)
|
)
|
||||||
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
|
|||||||
id_redirect[k] = a.id
|
id_redirect[k] = a.id
|
||||||
try:
|
try:
|
||||||
disamb_records.append(
|
disamb_records.append(
|
||||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ async def _judge_pair(
|
|||||||
pass
|
pass
|
||||||
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
||||||
ctx = {
|
ctx = {
|
||||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||||
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||||
"name_text_sim": name_text_sim,
|
"name_text_sim": name_text_sim,
|
||||||
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
ctx = {
|
ctx = {
|
||||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||||
"name_text_sim": name_text_sim,
|
"name_text_sim": name_text_sim,
|
||||||
"name_embed_sim": name_embed_sim,
|
"name_embed_sim": name_embed_sim,
|
||||||
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
|||||||
a = entity_nodes[i]
|
a = entity_nodes[i]
|
||||||
for j in range(i + 1, len(entity_nodes)):
|
for j in range(i + 1, len(entity_nodes)):
|
||||||
b = entity_nodes[j]
|
b = entity_nodes[j]
|
||||||
# 规则1:必须属于同一组(group_id相同,不同组的实体不重复)
|
# 规则1:必须属于同一组(end_user_id相同,不同组的实体不重复)
|
||||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||||
continue
|
continue
|
||||||
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
||||||
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
||||||
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
- max_rounds: upper bound for iterative passes (default 3)
|
- max_rounds: upper bound for iterative passes (default 3)
|
||||||
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
||||||
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
||||||
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
|
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||||
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
|
|
||||||
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
||||||
"""
|
"""
|
||||||
按 group_id 分块,避免跨组实体在同一块,减少无效候选对
|
按 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodes: 实体节点列表
|
nodes: 实体节点列表
|
||||||
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
"""
|
"""
|
||||||
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
||||||
for e in nodes:
|
for e in nodes:
|
||||||
gid = getattr(e, "group_id", None)
|
gid = getattr(e, "end_user_id", None)
|
||||||
groups.setdefault(str(gid), []).append(e)
|
groups.setdefault(str(gid), []).append(e)
|
||||||
blocks: List[List[ExtractedEntityNode]] = []
|
blocks: List[List[ExtractedEntityNode]] = []
|
||||||
for gid, arr in groups.items():
|
for gid, arr in groups.items():
|
||||||
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
|||||||
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
||||||
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
||||||
current_nodes = _collapse_nodes(current_nodes)
|
current_nodes = _collapse_nodes(current_nodes)
|
||||||
# 步骤2:分块(按group_id分块,避免跨组处理)
|
# 步骤2:分块(按end_user_id分块,避免跨组处理)
|
||||||
blocks = _partition_blocks(current_nodes)
|
blocks = _partition_blocks(current_nodes)
|
||||||
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
||||||
break
|
break
|
||||||
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
|
|||||||
a = entity_nodes[i]
|
a = entity_nodes[i]
|
||||||
b = entity_nodes[j]
|
b = entity_nodes[j]
|
||||||
# 必须同组
|
# 必须同组
|
||||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||||
continue
|
continue
|
||||||
ta = getattr(a, "entity_type", None)
|
ta = getattr(a, "entity_type", None)
|
||||||
tb = getattr(b, "entity_type", None)
|
tb = getattr(b, "entity_type", None)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
return ExtractedEntityNode(
|
return ExtractedEntityNode(
|
||||||
id=row.get("id"),
|
id=row.get("id"),
|
||||||
name=row.get("name") or "",
|
name=row.get("name") or "",
|
||||||
group_id=row.get("group_id") or "",
|
end_user_id=row.get("end_user_id") or "",
|
||||||
user_id=row.get("user_id") or "",
|
user_id=row.get("user_id") or "",
|
||||||
apply_id=row.get("apply_id") or "",
|
apply_id=row.get("apply_id") or "",
|
||||||
created_at=_parse_dt(row.get("created_at")),
|
created_at=_parse_dt(row.get("created_at")),
|
||||||
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
|
|
||||||
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
end_user_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||||
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
||||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||||
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
|||||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||||
"""
|
"""
|
||||||
第二层去重消歧:
|
第二层去重消歧:
|
||||||
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
|
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
|
||||||
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
||||||
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
||||||
"""
|
"""
|
||||||
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
|||||||
|
|
||||||
]
|
]
|
||||||
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
||||||
connector=connector, group_id=group_id,
|
connector=connector, end_user_id=end_user_id,
|
||||||
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
||||||
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
if pipeline_config is None:
|
if pipeline_config is None:
|
||||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||||
|
|
||||||
# 先探测 group_id,决定报告写入策略
|
# 先探测 end_user_id,决定报告写入策略
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
for dd in dialog_data_list:
|
for dd in dialog_data_list:
|
||||||
group_id = getattr(dd, "group_id", None)
|
end_user_id = getattr(dd, "end_user_id", None)
|
||||||
if group_id:
|
if end_user_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 第一层去重消歧
|
# 第一层去重消歧
|
||||||
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
|
|
||||||
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
||||||
try:
|
try:
|
||||||
if group_id:
|
if end_user_id:
|
||||||
if connector:
|
if connector:
|
||||||
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
entity_nodes=dedup_entity_nodes,
|
entity_nodes=dedup_entity_nodes,
|
||||||
statement_entity_edges=dedup_statement_entity_edges,
|
statement_entity_edges=dedup_statement_entity_edges,
|
||||||
entity_entity_edges=dedup_entity_entity_edges,
|
entity_entity_edges=dedup_entity_entity_edges,
|
||||||
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
else:
|
else:
|
||||||
print("Skip second-layer dedup: missing connector")
|
print("Skip second-layer dedup: missing connector")
|
||||||
else:
|
else:
|
||||||
print("Skip second-layer dedup: missing group_id")
|
print("Skip second-layer dedup: missing end_user_id")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Second-layer dedup failed: {e}")
|
print(f"Second-layer dedup failed: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
|
|||||||
for d_idx, dialog in enumerate(dialog_data_list):
|
for d_idx, dialog in enumerate(dialog_data_list):
|
||||||
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
|
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
|
||||||
for c_idx, chunk in enumerate(dialog.chunks):
|
for c_idx, chunk in enumerate(dialog.chunks):
|
||||||
all_chunks.append((chunk, dialog.group_id, dialogue_content))
|
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
|
||||||
chunk_metadata.append((d_idx, c_idx))
|
chunk_metadata.append((d_idx, c_idx))
|
||||||
|
|
||||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||||
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
|
|||||||
# 全局并行处理所有分块
|
# 全局并行处理所有分块
|
||||||
async def extract_for_chunk(chunk_data, chunk_index):
|
async def extract_for_chunk(chunk_data, chunk_index):
|
||||||
nonlocal completed_chunks
|
nonlocal completed_chunks
|
||||||
chunk, group_id, dialogue_content = chunk_data
|
chunk, end_user_id, dialogue_content = chunk_data
|
||||||
try:
|
try:
|
||||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
|
||||||
|
|
||||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||||
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
|
|||||||
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
|
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
|
||||||
config_id = dialog_data_list[0].config_id
|
config_id = dialog_data_list[0].config_id
|
||||||
|
|
||||||
# 加载DataConfig
|
# 加载MemoryConfig
|
||||||
data_config = None
|
memory_config = None
|
||||||
if config_id:
|
if config_id:
|
||||||
try:
|
try:
|
||||||
from app.db import SessionLocal
|
from app.db import SessionLocal
|
||||||
from app.repositories.data_config_repository import DataConfigRepository
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
data_config = DataConfigRepository.get_by_id(db, config_id)
|
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
if data_config and not data_config.emotion_enabled:
|
if memory_config and not memory_config.emotion_enabled:
|
||||||
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
|
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
|
||||||
return [{} for _ in dialog_data_list]
|
return [{} for _ in dialog_data_list]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
|
logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
|
||||||
return [{} for _ in dialog_data_list]
|
return [{} for _ in dialog_data_list]
|
||||||
else:
|
else:
|
||||||
logger.info("未找到config_id,跳过情绪提取")
|
logger.info("未找到config_id,跳过情绪提取")
|
||||||
return [{} for _ in dialog_data_list]
|
return [{} for _ in dialog_data_list]
|
||||||
|
|
||||||
# 如果配置未启用情绪提取,直接返回空映射
|
# 如果配置未启用情绪提取,直接返回空映射
|
||||||
if not data_config or not data_config.emotion_enabled:
|
if not memory_config or not memory_config.emotion_enabled:
|
||||||
logger.info("情绪提取未启用,跳过")
|
logger.info("情绪提取未启用,跳过")
|
||||||
return [{} for _ in dialog_data_list]
|
return [{} for _ in dialog_data_list]
|
||||||
|
|
||||||
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
|
|||||||
total_statements += 1
|
total_statements += 1
|
||||||
# 只处理用户的陈述句 (role 为 "user")
|
# 只处理用户的陈述句 (role 为 "user")
|
||||||
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
||||||
all_statements.append((statement, data_config))
|
all_statements.append((statement, memory_config))
|
||||||
statement_metadata.append((d_idx, statement.id))
|
statement_metadata.append((d_idx, statement.id))
|
||||||
filtered_statements += 1
|
filtered_statements += 1
|
||||||
|
|
||||||
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
|
|||||||
# 初始化情绪提取服务
|
# 初始化情绪提取服务
|
||||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||||
emotion_service = EmotionExtractionService(
|
emotion_service = EmotionExtractionService(
|
||||||
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None
|
llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# 全局并行处理所有陈述句
|
# 全局并行处理所有陈述句
|
||||||
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
|
|||||||
id=dialog_data.id,
|
id=dialog_data.id,
|
||||||
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
|
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
|
||||||
ref_id=dialog_data.ref_id,
|
ref_id=dialog_data.ref_id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
content=dialog_data.context.content if dialog_data.context else "",
|
content=dialog_data.context.content if dialog_data.context else "",
|
||||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
||||||
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
|
|||||||
id=chunk.id,
|
id=chunk.id,
|
||||||
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
|
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
|
||||||
dialog_id=dialog_data.id,
|
dialog_id=dialog_data.id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
content=chunk.content,
|
content=chunk.content,
|
||||||
chunk_embedding=chunk.chunk_embedding,
|
chunk_embedding=chunk.chunk_embedding,
|
||||||
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
|
|||||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
||||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||||
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
|
|||||||
statement_chunk_edge = StatementChunkEdge(
|
statement_chunk_edge = StatementChunkEdge(
|
||||||
source=statement.id,
|
source=statement.id,
|
||||||
target=chunk.id,
|
target=chunk.id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
)
|
)
|
||||||
@@ -1072,13 +1064,16 @@ class ExtractionOrchestrator:
|
|||||||
if statement.triplet_extraction_info:
|
if statement.triplet_extraction_info:
|
||||||
triplet_info = statement.triplet_extraction_info
|
triplet_info = statement.triplet_extraction_info
|
||||||
|
|
||||||
# 创建实体索引到ID的映射
|
# 创建实体索引到ID的映射(支持多种索引方式)
|
||||||
entity_idx_to_id = {}
|
entity_idx_to_id = {}
|
||||||
|
|
||||||
# 创建实体节点
|
# 创建实体节点
|
||||||
for entity_idx, entity in enumerate(triplet_info.entities):
|
for entity_idx, entity in enumerate(triplet_info.entities):
|
||||||
# 映射实体索引到实体ID
|
# 映射实体索引到实体ID(使用多个键以提高容错性)
|
||||||
|
# 1. 使用实体自己的 entity_idx
|
||||||
entity_idx_to_id[entity.entity_idx] = entity.id
|
entity_idx_to_id[entity.entity_idx] = entity.id
|
||||||
|
# 2. 使用枚举索引(从0开始)
|
||||||
|
entity_idx_to_id[entity_idx] = entity.id
|
||||||
|
|
||||||
if entity.id not in entity_id_set:
|
if entity.id not in entity_id_set:
|
||||||
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
|
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
|
||||||
@@ -1095,9 +1090,7 @@ class ExtractionOrchestrator:
|
|||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
expired_at=dialog_data.expired_at,
|
||||||
@@ -1112,9 +1105,7 @@ class ExtractionOrchestrator:
|
|||||||
source=statement.id,
|
source=statement.id,
|
||||||
target=entity.id,
|
target=entity.id,
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
)
|
)
|
||||||
@@ -1134,9 +1125,7 @@ class ExtractionOrchestrator:
|
|||||||
relation_type=triplet.predicate,
|
relation_type=triplet.predicate,
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
source_statement_id=statement.id,
|
source_statement_id=statement.id,
|
||||||
group_id=dialog_data.group_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
user_id=dialog_data.user_id,
|
|
||||||
apply_id=dialog_data.apply_id,
|
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
expired_at=dialog_data.expired_at,
|
||||||
@@ -1163,9 +1152,18 @@ class ExtractionOrchestrator:
|
|||||||
relationship_result
|
relationship_result
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
# 改进的警告信息,包含更多调试信息
|
||||||
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
|
missing_subject = "subject" if not subject_entity_id else ""
|
||||||
f"object_id={triplet.object_id}, statement_id={statement.id}"
|
missing_object = "object" if not object_entity_id else ""
|
||||||
|
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
|
||||||
|
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
|
||||||
|
f"object_id={triplet.object_id} ({triplet.object_name}), "
|
||||||
|
f"predicate={triplet.predicate}, "
|
||||||
|
f"statement_id={statement.id}, "
|
||||||
|
f"available_indices={sorted(entity_idx_to_id.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -1763,14 +1761,14 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
async def get_chunked_dialogs(
|
async def get_chunked_dialogs(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
group_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从测试数据生成分块对话
|
"""从测试数据生成分块对话
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: 分块策略(默认: RecursiveChunker)
|
chunker_strategy: 分块策略(默认: RecursiveChunker)
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
indices: 要处理的数据索引列表(可选)
|
indices: 要处理的数据索引列表(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1834,7 +1832,7 @@ async def get_chunked_dialogs(
|
|||||||
dialog_data = DialogData(
|
dialog_data = DialogData(
|
||||||
context=conversation_context,
|
context=conversation_context,
|
||||||
ref_id=data['id'],
|
ref_id=data['id'],
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
metadata=dialog_metadata,
|
metadata=dialog_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1936,7 +1934,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
|||||||
|
|
||||||
async def get_chunked_dialogs_with_preprocessing(
|
async def get_chunked_dialogs_with_preprocessing(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
group_id: str = "default",
|
end_user_id: str = "default",
|
||||||
user_id: str = "default",
|
user_id: str = "default",
|
||||||
apply_id: str = "default",
|
apply_id: str = "default",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
@@ -1948,7 +1946,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunker_strategy: 分块策略
|
chunker_strategy: 分块策略
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
apply_id: 应用ID
|
apply_id: 应用ID
|
||||||
indices: 要处理的数据索引列表
|
indices: 要处理的数据索引列表
|
||||||
@@ -1976,11 +1974,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
|||||||
indices=indices,
|
indices=indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置 group_id, user_id, apply_id
|
# 设置 end_user_id
|
||||||
for dd in preprocessed_data:
|
for dd in preprocessed_data:
|
||||||
dd.group_id = group_id
|
dd.end_user_id = end_user_id
|
||||||
dd.user_id = user_id
|
|
||||||
dd.apply_id = apply_id
|
|
||||||
|
|
||||||
# 步骤2: 语义剪枝
|
# 步骤2: 语义剪枝
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -193,9 +193,9 @@ async def _process_chunk_summary(
|
|||||||
node = MemorySummaryNode(
|
node = MemorySummaryNode(
|
||||||
id=uuid4().hex,
|
id=uuid4().hex,
|
||||||
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
||||||
group_id=dialog.group_id,
|
end_user_id=dialog.end_user_id,
|
||||||
user_id=dialog.user_id,
|
user_id=dialog.end_user_id,
|
||||||
apply_id=dialog.apply_id,
|
apply_id=dialog.end_user_id,
|
||||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(),
|
||||||
expired_at=datetime(9999, 12, 31),
|
expired_at=datetime(9999, 12, 31),
|
||||||
|
|||||||
@@ -82,12 +82,12 @@ class StatementExtractor:
|
|||||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||||
"""Process a single chunk and return extracted statements
|
"""Process a single chunk and return extracted statements
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunk: Chunk object to process
|
chunk: Chunk object to process
|
||||||
group_id: Group ID to assign to all statements in this chunk
|
end_user_id: Group ID to assign to all statements in this chunk
|
||||||
dialogue_content: Full dialogue content to provide as context
|
dialogue_content: Full dialogue content to provide as context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -158,7 +158,7 @@ class StatementExtractor:
|
|||||||
temporal_info=temporal_type,
|
temporal_info=temporal_type,
|
||||||
relevence_info=relevence_info,
|
relevence_info=relevence_info,
|
||||||
chunk_id=chunk.id,
|
chunk_id=chunk.id,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
speaker=chunk_speaker,
|
speaker=chunk_speaker,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -184,10 +184,10 @@ class StatementExtractor:
|
|||||||
|
|
||||||
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
||||||
|
|
||||||
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
|
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
|
||||||
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
|
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
|
||||||
return_exceptions=True
|
return_exceptions=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -225,7 +225,7 @@ class StatementExtractor:
|
|||||||
for i, statement in enumerate(statements, 1):
|
for i, statement in enumerate(statements, 1):
|
||||||
f.write(f"Statement {i}:\n")
|
f.write(f"Statement {i}:\n")
|
||||||
f.write(f"Id: {statement.id}\n")
|
f.write(f"Id: {statement.id}\n")
|
||||||
f.write(f"Group Id: {statement.group_id}\n")
|
f.write(f"Group Id: {statement.end_user_id}\n")
|
||||||
f.write(f"Content: {statement.statement}\n")
|
f.write(f"Content: {statement.statement}\n")
|
||||||
f.write(f"Type: {statement.stmt_type.value}\n")
|
f.write(f"Type: {statement.stmt_type.value}\n")
|
||||||
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
||||||
@@ -298,7 +298,7 @@ class StatementExtractor:
|
|||||||
|
|
||||||
dialog_sections.append({
|
dialog_sections.append({
|
||||||
"dialog_id": dialog.ref_id,
|
"dialog_id": dialog.ref_id,
|
||||||
"group_id": dialog.group_id,
|
"end_user_id": dialog.end_user_id,
|
||||||
"content": dialog.content if getattr(dialog, "content", None) else "",
|
"content": dialog.content if getattr(dialog, "content", None) else "",
|
||||||
"strong": strong_relations,
|
"strong": strong_relations,
|
||||||
"weak": weak_relations,
|
"weak": weak_relations,
|
||||||
@@ -312,7 +312,7 @@ class StatementExtractor:
|
|||||||
for idx, section in enumerate(dialog_sections, 1):
|
for idx, section in enumerate(dialog_sections, 1):
|
||||||
f.write(f"Dialog {idx}:\n")
|
f.write(f"Dialog {idx}:\n")
|
||||||
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
||||||
f.write(f"Group ID: {section.get('group_id', '')}\n")
|
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
|
||||||
f.write("Content:\n")
|
f.write("Content:\n")
|
||||||
f.write(f"{section.get('content', '')}\n")
|
f.write(f"{section.get('content', '')}\n")
|
||||||
f.write("-" * 40 + "\n\n")
|
f.write("-" * 40 + "\n\n")
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class TemporalExtractor:
|
|||||||
prompt_logger.info("")
|
prompt_logger.info("")
|
||||||
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
||||||
prompt_logger.info(
|
prompt_logger.info(
|
||||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
|
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class TripletExtractor:
|
|||||||
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
||||||
try:
|
try:
|
||||||
prompt_logger.info(
|
prompt_logger.info(
|
||||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
|
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_time: Optional[datetime] = None
|
current_time: Optional[datetime] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -91,7 +91,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||||
group_id: 组ID(可选,用于过滤)
|
end_user_id: 组ID(可选,用于过滤)
|
||||||
current_time: 当前时间(可选,默认使用系统时间)
|
current_time: 当前时间(可选,默认使用系统时间)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -123,7 +123,7 @@ class AccessHistoryManager:
|
|||||||
for attempt in range(self.max_retries):
|
for attempt in range(self.max_retries):
|
||||||
try:
|
try:
|
||||||
# 步骤1:读取当前节点状态
|
# 步骤1:读取当前节点状态
|
||||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||||
|
|
||||||
if not node_data:
|
if not node_data:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -142,7 +142,7 @@ class AccessHistoryManager:
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
update_data=update_data,
|
update_data=update_data,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -172,7 +172,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_ids: List[str],
|
node_ids: List[str],
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_time: Optional[datetime] = None
|
current_time: Optional[datetime] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -184,7 +184,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_ids: 节点ID列表
|
node_ids: 节点ID列表
|
||||||
node_label: 节点标签(所有节点必须是同一类型)
|
node_label: 节点标签(所有节点必须是同一类型)
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
current_time: 当前时间(可选)
|
current_time: 当前时间(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -202,7 +202,7 @@ class AccessHistoryManager:
|
|||||||
task = self.record_access(
|
task = self.record_access(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
current_time=current_time
|
current_time=current_time
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
@@ -235,7 +235,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
检查节点数据的一致性
|
检查节点数据的一致性
|
||||||
@@ -249,14 +249,14 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||||
- 一致性检查结果枚举
|
- 一致性检查结果枚举
|
||||||
- 错误描述(如果不一致)
|
- 错误描述(如果不一致)
|
||||||
"""
|
"""
|
||||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||||
|
|
||||||
if not node_data:
|
if not node_data:
|
||||||
return ConsistencyCheckResult.CONSISTENT, None
|
return ConsistencyCheckResult.CONSISTENT, None
|
||||||
@@ -305,7 +305,7 @@ class AccessHistoryManager:
|
|||||||
async def check_batch_consistency(
|
async def check_batch_consistency(
|
||||||
self,
|
self,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 1000
|
limit: int = 1000
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -313,7 +313,7 @@ class AccessHistoryManager:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
limit: 检查的最大节点数
|
limit: 检查的最大节点数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -329,16 +329,16 @@ class AccessHistoryManager:
|
|||||||
MATCH (n:{node_label})
|
MATCH (n:{node_label})
|
||||||
WHERE n.access_history IS NOT NULL
|
WHERE n.access_history IS NOT NULL
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " AND n.group_id = $group_id"
|
query += " AND n.end_user_id = $end_user_id"
|
||||||
query += """
|
query += """
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {"limit": limit}
|
params = {"limit": limit}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params["group_id"] = group_id
|
params["end_user_id"] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
node_ids = [r['id'] for r in results]
|
node_ids = [r['id'] for r in results]
|
||||||
@@ -351,7 +351,7 @@ class AccessHistoryManager:
|
|||||||
result, message = await self.check_consistency(
|
result, message = await self.check_consistency(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if result == ConsistencyCheckResult.CONSISTENT:
|
if result == ConsistencyCheckResult.CONSISTENT:
|
||||||
@@ -387,7 +387,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
自动修复节点的数据不一致问题
|
自动修复节点的数据不一致问题
|
||||||
@@ -401,7 +401,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 修复成功返回True,否则返回False
|
bool: 修复成功返回True,否则返回False
|
||||||
@@ -411,7 +411,7 @@ class AccessHistoryManager:
|
|||||||
result, message = await self.check_consistency(
|
result, message = await self.check_consistency(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if result == ConsistencyCheckResult.CONSISTENT:
|
if result == ConsistencyCheckResult.CONSISTENT:
|
||||||
@@ -419,7 +419,7 @@ class AccessHistoryManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# 获取节点数据
|
# 获取节点数据
|
||||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||||
if not node_data:
|
if not node_data:
|
||||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||||
return False
|
return False
|
||||||
@@ -457,8 +457,8 @@ class AccessHistoryManager:
|
|||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{node_label} {{id: $node_id}})
|
MATCH (n:{node_label} {{id: $node_id}})
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " WHERE n.group_id = $group_id"
|
query += " WHERE n.end_user_id = $end_user_id"
|
||||||
query += """
|
query += """
|
||||||
SET n += $repair_data
|
SET n += $repair_data
|
||||||
RETURN n
|
RETURN n
|
||||||
@@ -468,8 +468,8 @@ class AccessHistoryManager:
|
|||||||
'node_id': node_id,
|
'node_id': node_id,
|
||||||
'repair_data': repair_data
|
'repair_data': repair_data
|
||||||
}
|
}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
await self.connector.execute_query(query, **params)
|
await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
@@ -491,7 +491,7 @@ class AccessHistoryManager:
|
|||||||
self,
|
self,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取节点数据
|
获取节点数据
|
||||||
@@ -499,7 +499,7 @@ class AccessHistoryManager:
|
|||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||||
@@ -507,8 +507,8 @@ class AccessHistoryManager:
|
|||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:{node_label} {{id: $node_id}})
|
MATCH (n:{node_label} {{id: $node_id}})
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " WHERE n.group_id = $group_id"
|
query += " WHERE n.end_user_id = $end_user_id"
|
||||||
query += """
|
query += """
|
||||||
RETURN n.id as id,
|
RETURN n.id as id,
|
||||||
n.importance_score as importance_score,
|
n.importance_score as importance_score,
|
||||||
@@ -519,8 +519,8 @@ class AccessHistoryManager:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
params = {'node_id': node_id}
|
params = {'node_id': node_id}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
@@ -585,7 +585,7 @@ class AccessHistoryManager:
|
|||||||
node_id: str,
|
node_id: str,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
update_data: Dict[str, Any],
|
update_data: Dict[str, Any],
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
原子性更新节点(使用乐观锁)
|
原子性更新节点(使用乐观锁)
|
||||||
@@ -597,7 +597,7 @@ class AccessHistoryManager:
|
|||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
node_label: 节点标签
|
node_label: 节点标签
|
||||||
update_data: 更新数据
|
update_data: 更新数据
|
||||||
group_id: 组ID(可选)
|
end_user_id: 组ID(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 更新后的节点数据
|
Dict[str, Any]: 更新后的节点数据
|
||||||
@@ -606,13 +606,13 @@ class AccessHistoryManager:
|
|||||||
RuntimeError: 如果更新失败或发生版本冲突
|
RuntimeError: 如果更新失败或发生版本冲突
|
||||||
"""
|
"""
|
||||||
# 定义事务函数
|
# 定义事务函数
|
||||||
async def update_transaction(tx, node_id, node_label, update_data, group_id):
|
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||||
# 步骤1:读取当前节点并获取版本号
|
# 步骤1:读取当前节点并获取版本号
|
||||||
read_query = f"""
|
read_query = f"""
|
||||||
MATCH (n:{node_label} {{id: $node_id}})
|
MATCH (n:{node_label} {{id: $node_id}})
|
||||||
"""
|
"""
|
||||||
if group_id:
|
if end_user_id:
|
||||||
read_query += " WHERE n.group_id = $group_id"
|
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||||
read_query += """
|
read_query += """
|
||||||
RETURN n.id as id,
|
RETURN n.id as id,
|
||||||
n.version as version,
|
n.version as version,
|
||||||
@@ -624,8 +624,8 @@ class AccessHistoryManager:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
read_params = {'node_id': node_id}
|
read_params = {'node_id': node_id}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
read_params['group_id'] = group_id
|
read_params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
read_result = await tx.run(read_query, **read_params)
|
read_result = await tx.run(read_query, **read_params)
|
||||||
current_node = await read_result.single()
|
current_node = await read_result.single()
|
||||||
@@ -656,8 +656,8 @@ class AccessHistoryManager:
|
|||||||
|
|
||||||
# 构建 WHERE 子句
|
# 构建 WHERE 子句
|
||||||
where_conditions = []
|
where_conditions = []
|
||||||
if group_id:
|
if end_user_id:
|
||||||
where_conditions.append("n.group_id = $group_id")
|
where_conditions.append("n.end_user_id = $end_user_id")
|
||||||
|
|
||||||
# 添加版本检查
|
# 添加版本检查
|
||||||
if current_version > 0:
|
if current_version > 0:
|
||||||
@@ -695,8 +695,8 @@ class AccessHistoryManager:
|
|||||||
'last_access_time': update_data['last_access_time'],
|
'last_access_time': update_data['last_access_time'],
|
||||||
'access_count': update_data['access_count']
|
'access_count': update_data['access_count']
|
||||||
}
|
}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
update_params['group_id'] = group_id
|
update_params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
update_result = await tx.run(update_query, **update_params)
|
update_result = await tx.run(update_query, **update_params)
|
||||||
updated_node = await update_result.single()
|
updated_node = await update_result.single()
|
||||||
@@ -720,7 +720,7 @@ class AccessHistoryManager:
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
update_data=update_data,
|
update_data=update_data,
|
||||||
group_id=group_id
|
end_user_id=end_user_id
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -11,9 +11,10 @@ Functions:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
from uuid import UUID
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.repositories.data_config_repository import DataConfigRepository
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||||
|
|
||||||
|
|
||||||
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
|
|||||||
|
|
||||||
def load_actr_config_from_db(
|
def load_actr_config_from_db(
|
||||||
db: Session,
|
db: Session,
|
||||||
config_id: Optional[int] = None
|
config_id: Optional[UUID] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
从数据库加载 ACT-R 配置参数
|
从数据库加载 ACT-R 配置参数
|
||||||
|
|
||||||
从 PostgreSQL 的 data_config 表读取配置参数,
|
从 PostgreSQL 的 memory_config 表读取配置参数,
|
||||||
并计算派生参数(如 forgetting_rate)。
|
并计算派生参数(如 forgetting_rate)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
|
|||||||
|
|
||||||
# 从数据库加载配置
|
# 从数据库加载配置
|
||||||
try:
|
try:
|
||||||
repository = DataConfigRepository()
|
repository = MemoryConfigRepository()
|
||||||
db_config = repository.get_by_id(db, config_id)
|
db_config = repository.get_by_id(db, config_id)
|
||||||
|
|
||||||
if db_config is None:
|
if db_config is None:
|
||||||
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
|
|||||||
|
|
||||||
def create_actr_calculator_from_config(
|
def create_actr_calculator_from_config(
|
||||||
db: Session,
|
db: Session,
|
||||||
config_id: Optional[int] = None
|
config_id: Optional[UUID] = None
|
||||||
) -> ACTRCalculator:
|
) -> ACTRCalculator:
|
||||||
"""
|
"""
|
||||||
从数据库配置创建 ACTRCalculator 实例
|
从数据库配置创建 ACTRCalculator 实例
|
||||||
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
|
|||||||
ValueError: 如果指定的 config_id 不存在
|
ValueError: 如果指定的 config_id 不存在
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from sqlalchemy.orm import Session
|
|
||||||
>>> db = Session()
|
|
||||||
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
|
|
||||||
>>> # 使用计算器
|
|
||||||
>>> activation = calculator.calculate_memory_activation(...)
|
|
||||||
"""
|
"""
|
||||||
# 加载配置
|
# 加载配置
|
||||||
config = load_actr_config_from_db(db, config_id)
|
config = load_actr_config_from_db(db, config_id)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ Classes:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
from uuid import UUID
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||||
@@ -66,10 +67,10 @@ class ForgettingScheduler:
|
|||||||
|
|
||||||
async def run_forgetting_cycle(
|
async def run_forgetting_cycle(
|
||||||
self,
|
self,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
max_merge_batch_size: int = 100,
|
max_merge_batch_size: int = 100,
|
||||||
min_days_since_access: int = 30,
|
min_days_since_access: int = 30,
|
||||||
config_id: Optional[int] = None,
|
config_id: Optional[UUID] = None,
|
||||||
db = None
|
db = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -77,7 +78,7 @@ class ForgettingScheduler:
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||||
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
||||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||||
config_id: 配置ID(可选,用于获取 llm_id)
|
config_id: 配置ID(可选,用于获取 llm_id)
|
||||||
@@ -107,19 +108,19 @@ class ForgettingScheduler:
|
|||||||
start_time_iso = start_time.isoformat()
|
start_time_iso = start_time.isoformat()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"开始遗忘周期: group_id={group_id}, "
|
f"开始遗忘周期: end_user_id={end_user_id}, "
|
||||||
f"max_batch={max_merge_batch_size}, "
|
f"max_batch={max_merge_batch_size}, "
|
||||||
f"min_days={min_days_since_access}"
|
f"min_days={min_days_since_access}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 步骤1:统计遗忘前的节点数量
|
# 步骤1:统计遗忘前的节点数量
|
||||||
nodes_before = await self._count_knowledge_nodes(group_id)
|
nodes_before = await self._count_knowledge_nodes(end_user_id)
|
||||||
logger.info(f"遗忘前节点总数: {nodes_before}")
|
logger.info(f"遗忘前节点总数: {nodes_before}")
|
||||||
|
|
||||||
# 步骤2:识别可遗忘的节点对
|
# 步骤2:识别可遗忘的节点对
|
||||||
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
min_days_since_access=min_days_since_access
|
min_days_since_access=min_days_since_access
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -213,7 +214,7 @@ class ForgettingScheduler:
|
|||||||
'statement_text': pair['statement_text'],
|
'statement_text': pair['statement_text'],
|
||||||
'statement_activation': pair['statement_activation'],
|
'statement_activation': pair['statement_activation'],
|
||||||
'statement_importance': pair['statement_importance'],
|
'statement_importance': pair['statement_importance'],
|
||||||
'group_id': group_id
|
'end_user_id': end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
entity_node = {
|
entity_node = {
|
||||||
@@ -222,7 +223,7 @@ class ForgettingScheduler:
|
|||||||
'entity_type': pair['entity_type'],
|
'entity_type': pair['entity_type'],
|
||||||
'entity_activation': pair['entity_activation'],
|
'entity_activation': pair['entity_activation'],
|
||||||
'entity_importance': pair['entity_importance'],
|
'entity_importance': pair['entity_importance'],
|
||||||
'group_id': group_id
|
'end_user_id': end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
# 融合节点
|
# 融合节点
|
||||||
@@ -262,7 +263,7 @@ class ForgettingScheduler:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 步骤6:统计遗忘后的节点数量
|
# 步骤6:统计遗忘后的节点数量
|
||||||
nodes_after = await self._count_knowledge_nodes(group_id)
|
nodes_after = await self._count_knowledge_nodes(end_user_id)
|
||||||
logger.info(f"遗忘后节点总数: {nodes_after}")
|
logger.info(f"遗忘后节点总数: {nodes_after}")
|
||||||
|
|
||||||
# 步骤7:生成遗忘报告
|
# 步骤7:生成遗忘报告
|
||||||
@@ -315,7 +316,7 @@ class ForgettingScheduler:
|
|||||||
|
|
||||||
async def _count_knowledge_nodes(
|
async def _count_knowledge_nodes(
|
||||||
self,
|
self,
|
||||||
group_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
统计知识层节点总数
|
统计知识层节点总数
|
||||||
@@ -323,7 +324,7 @@ class ForgettingScheduler:
|
|||||||
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: 知识层节点总数
|
int: 知识层节点总数
|
||||||
@@ -333,16 +334,16 @@ class ForgettingScheduler:
|
|||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " AND n.group_id = $group_id"
|
query += " AND n.end_user_id = $end_user_id"
|
||||||
|
|
||||||
query += """
|
query += """
|
||||||
RETURN count(n) as total
|
RETURN count(n) as total
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {}
|
params = {}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ Classes:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
from uuid import UUID
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
@@ -90,7 +91,7 @@ class ForgettingStrategy:
|
|||||||
|
|
||||||
async def find_forgettable_nodes(
|
async def find_forgettable_nodes(
|
||||||
self,
|
self,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
min_days_since_access: int = 30
|
min_days_since_access: int = 30
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -102,7 +103,7 @@ class ForgettingStrategy:
|
|||||||
3. Statement 和 Entity 之间存在关系边
|
3. Statement 和 Entity 之间存在关系边
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -136,8 +137,8 @@ class ForgettingStrategy:
|
|||||||
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if group_id:
|
if end_user_id:
|
||||||
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
|
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
|
||||||
|
|
||||||
query += """
|
query += """
|
||||||
RETURN s.id as statement_id,
|
RETURN s.id as statement_id,
|
||||||
@@ -159,8 +160,8 @@ class ForgettingStrategy:
|
|||||||
'threshold': self.forgetting_threshold,
|
'threshold': self.forgetting_threshold,
|
||||||
'cutoff_time': cutoff_time_iso
|
'cutoff_time': cutoff_time_iso
|
||||||
}
|
}
|
||||||
if group_id:
|
if end_user_id:
|
||||||
params['group_id'] = group_id
|
params['end_user_id'] = end_user_id
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
|
|
||||||
@@ -176,7 +177,7 @@ class ForgettingStrategy:
|
|||||||
self,
|
self,
|
||||||
statement_node: Dict[str, Any],
|
statement_node: Dict[str, Any],
|
||||||
entity_node: Dict[str, Any],
|
entity_node: Dict[str, Any],
|
||||||
config_id: Optional[int] = None,
|
config_id: Optional[UUID] = None,
|
||||||
db = None
|
db = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -247,8 +248,8 @@ class ForgettingStrategy:
|
|||||||
entity_activation = entity_node['entity_activation']
|
entity_activation = entity_node['entity_activation']
|
||||||
entity_importance = entity_node['entity_importance']
|
entity_importance = entity_node['entity_importance']
|
||||||
|
|
||||||
# 获取 group_id(从 statement 或 entity 节点)
|
# 获取 end_user_id(从 statement 或 entity 节点)
|
||||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
|
||||||
|
|
||||||
# 生成摘要内容
|
# 生成摘要内容
|
||||||
summary_text = await self._generate_summary(
|
summary_text = await self._generate_summary(
|
||||||
@@ -325,7 +326,7 @@ class ForgettingStrategy:
|
|||||||
last_access_time: $current_time,
|
last_access_time: $current_time,
|
||||||
access_count: 1,
|
access_count: 1,
|
||||||
version: 1,
|
version: 1,
|
||||||
group_id: $group_id,
|
end_user_id: $end_user_id,
|
||||||
created_at: datetime($current_time),
|
created_at: datetime($current_time),
|
||||||
merged_at: datetime($current_time)
|
merged_at: datetime($current_time)
|
||||||
})
|
})
|
||||||
@@ -423,7 +424,7 @@ class ForgettingStrategy:
|
|||||||
'inherited_activation': inherited_activation,
|
'inherited_activation': inherited_activation,
|
||||||
'inherited_importance': inherited_importance,
|
'inherited_importance': inherited_importance,
|
||||||
'current_time': current_time_iso,
|
'current_time': current_time_iso,
|
||||||
'group_id': group_id
|
'end_user_id': end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -462,7 +463,7 @@ class ForgettingStrategy:
|
|||||||
statement_text: str,
|
statement_text: str,
|
||||||
entity_name: str,
|
entity_name: str,
|
||||||
entity_type: str,
|
entity_type: str,
|
||||||
config_id: Optional[int] = None,
|
config_id: Optional[UUID] = None,
|
||||||
db = None
|
db = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -527,7 +528,7 @@ class ForgettingStrategy:
|
|||||||
statement_text, entity_name, entity_type
|
statement_text, entity_name, entity_type
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_llm_client(self, db, config_id: int):
|
async def _get_llm_client(self, db, config_id: UUID):
|
||||||
"""
|
"""
|
||||||
从数据库获取 LLM 客户端
|
从数据库获取 LLM 客户端
|
||||||
|
|
||||||
@@ -539,11 +540,11 @@ class ForgettingStrategy:
|
|||||||
LLM 客户端实例,如果无法获取则返回 None
|
LLM 客户端实例,如果无法获取则返回 None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.repositories.data_config_repository import DataConfigRepository
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
|
||||||
# 从数据库读取配置
|
# 从数据库读取配置
|
||||||
repository = DataConfigRepository()
|
repository = MemoryConfigRepository()
|
||||||
db_config = repository.get_by_id(db, config_id)
|
db_config = repository.get_by_id(db, config_id)
|
||||||
|
|
||||||
if db_config is None or db_config.llm_id is None:
|
if db_config is None or db_config.llm_id is None:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ __all__ = [
|
|||||||
async def run_hybrid_search(
|
async def run_hybrid_search(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
group_id: str | None = None,
|
end_user_id: str | None = None,
|
||||||
apply_id: str | None = None,
|
apply_id: str | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
@@ -54,7 +54,7 @@ async def run_hybrid_search(
|
|||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||||
group_id: 组ID过滤
|
end_user_id: 组ID过滤
|
||||||
apply_id: 应用ID过滤
|
apply_id: 应用ID过滤
|
||||||
user_id: 用户ID过滤
|
user_id: 用户ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
@@ -104,7 +104,7 @@ async def run_hybrid_search(
|
|||||||
# 执行搜索
|
# 执行搜索
|
||||||
result = await strategy.search(
|
result = await strategy.search(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include,
|
include=include,
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
|
|||||||
@@ -77,7 +77,7 @@
|
|||||||
# async def search(
|
# async def search(
|
||||||
# self,
|
# self,
|
||||||
# query_text: str,
|
# query_text: str,
|
||||||
# group_id: Optional[str] = None,
|
# end_user_id: Optional[str] = None,
|
||||||
# limit: int = 50,
|
# limit: int = 50,
|
||||||
# include: Optional[List[str]] = None,
|
# include: Optional[List[str]] = None,
|
||||||
# **kwargs
|
# **kwargs
|
||||||
@@ -86,7 +86,7 @@
|
|||||||
|
|
||||||
# Args:
|
# Args:
|
||||||
# query_text: 查询文本
|
# query_text: 查询文本
|
||||||
# group_id: 可选的组ID过滤
|
# end_user_id: 可选的组ID过滤
|
||||||
# limit: 每个类别的最大结果数
|
# limit: 每个类别的最大结果数
|
||||||
# include: 要包含的搜索类别列表
|
# include: 要包含的搜索类别列表
|
||||||
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||||
@@ -94,7 +94,7 @@
|
|||||||
# Returns:
|
# Returns:
|
||||||
# SearchResult: 搜索结果对象
|
# SearchResult: 搜索结果对象
|
||||||
# """
|
# """
|
||||||
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||||
|
|
||||||
# # 从kwargs中获取参数
|
# # 从kwargs中获取参数
|
||||||
# alpha = kwargs.get("alpha", self.alpha)
|
# alpha = kwargs.get("alpha", self.alpha)
|
||||||
@@ -107,14 +107,14 @@
|
|||||||
# # 并行执行关键词搜索和语义搜索
|
# # 并行执行关键词搜索和语义搜索
|
||||||
# keyword_result = await self.keyword_strategy.search(
|
# keyword_result = await self.keyword_strategy.search(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# include=include_list
|
# include=include_list
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# semantic_result = await self.semantic_strategy.search(
|
# semantic_result = await self.semantic_strategy.search(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# include=include_list
|
# include=include_list
|
||||||
# )
|
# )
|
||||||
@@ -139,7 +139,7 @@
|
|||||||
# metadata = self._create_metadata(
|
# metadata = self._create_metadata(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# search_type="hybrid",
|
# search_type="hybrid",
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# include=include_list,
|
# include=include_list,
|
||||||
# alpha=alpha,
|
# alpha=alpha,
|
||||||
@@ -165,7 +165,7 @@
|
|||||||
# metadata=self._create_metadata(
|
# metadata=self._create_metadata(
|
||||||
# query_text=query_text,
|
# query_text=query_text,
|
||||||
# search_type="hybrid",
|
# search_type="hybrid",
|
||||||
# group_id=group_id,
|
# end_user_id=end_user_id,
|
||||||
# limit=limit,
|
# limit=limit,
|
||||||
# error=str(e)
|
# error=str(e)
|
||||||
# )
|
# )
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
group_id: 可选的组ID过滤
|
end_user_id: 可选的组ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
include: 要包含的搜索类别列表
|
include: 要包含的搜索类别列表
|
||||||
**kwargs: 其他搜索参数
|
**kwargs: 其他搜索参数
|
||||||
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
Returns:
|
Returns:
|
||||||
SearchResult: 搜索结果对象
|
SearchResult: 搜索结果对象
|
||||||
"""
|
"""
|
||||||
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||||
|
|
||||||
# 获取有效的搜索类别
|
# 获取有效的搜索类别
|
||||||
include_list = self._get_include_list(include)
|
include_list = self._get_include_list(include)
|
||||||
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
results_dict = await search_graph(
|
results_dict = await search_graph(
|
||||||
connector=self.connector,
|
connector=self.connector,
|
||||||
q=query_text,
|
q=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
metadata = self._create_metadata(
|
metadata = self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="keyword",
|
search_type="keyword",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
|||||||
metadata=self._create_metadata(
|
metadata=self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="keyword",
|
search_type="keyword",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
|
|||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
group_id: 可选的组ID过滤
|
end_user_id: 可选的组ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
||||||
**kwargs: 其他搜索参数
|
**kwargs: 其他搜索参数
|
||||||
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
|
|||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
search_type: str,
|
search_type: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
|
|||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
search_type: 搜索类型
|
search_type: 搜索类型
|
||||||
group_id: 组ID
|
end_user_id: 组ID
|
||||||
limit: 结果限制
|
limit: 结果限制
|
||||||
**kwargs: 其他元数据
|
**kwargs: 其他元数据
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
|
|||||||
metadata = {
|
metadata = {
|
||||||
"query": query_text,
|
"query": query_text,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"group_id": group_id,
|
"end_user_id": end_user_id,
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_text: 查询文本
|
query_text: 查询文本
|
||||||
group_id: 可选的组ID过滤
|
end_user_id: 可选的组ID过滤
|
||||||
limit: 每个类别的最大结果数
|
limit: 每个类别的最大结果数
|
||||||
include: 要包含的搜索类别列表
|
include: 要包含的搜索类别列表
|
||||||
**kwargs: 其他搜索参数
|
**kwargs: 其他搜索参数
|
||||||
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
Returns:
|
Returns:
|
||||||
SearchResult: 搜索结果对象
|
SearchResult: 搜索结果对象
|
||||||
"""
|
"""
|
||||||
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||||
|
|
||||||
# 获取有效的搜索类别
|
# 获取有效的搜索类别
|
||||||
include_list = self._get_include_list(include)
|
include_list = self._get_include_list(include)
|
||||||
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
connector=self.connector,
|
connector=self.connector,
|
||||||
embedder_client=self.embedder_client,
|
embedder_client=self.embedder_client,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
metadata = self._create_metadata(
|
metadata = self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="semantic",
|
search_type="semantic",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include=include_list
|
include=include_list
|
||||||
)
|
)
|
||||||
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
|||||||
metadata=self._create_metadata(
|
metadata=self._create_metadata(
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
search_type="semantic",
|
search_type="semantic",
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
|||||||
target_keys = [
|
target_keys = [
|
||||||
"id",
|
"id",
|
||||||
"statement",
|
"statement",
|
||||||
"group_id",
|
"end_user_id",
|
||||||
"chunk_id",
|
"chunk_id",
|
||||||
"created_at",
|
"created_at",
|
||||||
"expired_at",
|
"expired_at",
|
||||||
@@ -75,7 +75,7 @@ async def get_data(result):
|
|||||||
"""
|
"""
|
||||||
EXCLUDE_FIELDS = {
|
EXCLUDE_FIELDS = {
|
||||||
"user_id",
|
"user_id",
|
||||||
"group_id",
|
"end_user_id",
|
||||||
"entity_type",
|
"entity_type",
|
||||||
"connect_strength",
|
"connect_strength",
|
||||||
"relationship_type",
|
"relationship_type",
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class ConfigAuditLogger:
|
|||||||
self,
|
self,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
success: bool = True,
|
success: bool = True,
|
||||||
details: Optional[Dict[str, Any]] = None
|
details: Optional[Dict[str, Any]] = None
|
||||||
):
|
):
|
||||||
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
|
|||||||
Args:
|
Args:
|
||||||
config_id: 配置 ID
|
config_id: 配置 ID
|
||||||
user_id: 用户 ID(可选)
|
user_id: 用户 ID(可选)
|
||||||
group_id: 组 ID(可选)
|
end_user_id: 组 ID(可选)
|
||||||
success: 是否成功
|
success: 是否成功
|
||||||
details: 详细信息(可选)
|
details: 详细信息(可选)
|
||||||
"""
|
"""
|
||||||
result = "SUCCESS" if success else "FAILED"
|
result = "SUCCESS" if success else "FAILED"
|
||||||
msg = (
|
msg = (
|
||||||
f"CONFIG_LOAD config_id={config_id} "
|
f"CONFIG_LOAD config_id={config_id} "
|
||||||
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
|
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
|
||||||
f"result={result}"
|
f"result={result}"
|
||||||
)
|
)
|
||||||
if details:
|
if details:
|
||||||
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
|
|||||||
self,
|
self,
|
||||||
operation: str,
|
operation: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
group_id: str,
|
end_user_id: str,
|
||||||
success: bool = True,
|
success: bool = True,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
error: Optional[str] = None,
|
error: Optional[str] = None,
|
||||||
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
|
|||||||
Args:
|
Args:
|
||||||
operation: 操作类型(WRITE, READ 等)
|
operation: 操作类型(WRITE, READ 等)
|
||||||
config_id: 配置 ID
|
config_id: 配置 ID
|
||||||
group_id: 组 ID
|
end_user_id: 组 ID
|
||||||
success: 是否成功
|
success: 是否成功
|
||||||
duration: 操作耗时(秒)
|
duration: 操作耗时(秒)
|
||||||
error: 错误信息(可选)
|
error: 错误信息(可选)
|
||||||
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
|
|||||||
result = "SUCCESS" if success else "FAILED"
|
result = "SUCCESS" if success else "FAILED"
|
||||||
msg = (
|
msg = (
|
||||||
f"{operation.upper()} config_id={config_id} "
|
f"{operation.upper()} config_id={config_id} "
|
||||||
f"group={group_id} result={result}"
|
f"group={end_user_id} result={result}"
|
||||||
)
|
)
|
||||||
if duration is not None:
|
if duration is not None:
|
||||||
msg += f" duration={duration:.2f}s"
|
msg += f" duration={duration:.2f}s"
|
||||||
|
|||||||
1
api/app/core/models/scripts/__init__.py
Normal file
1
api/app/core/models/scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""模型配置脚本模块"""
|
||||||
174
api/app/core/models/scripts/bedrock_models.yaml
Normal file
174
api/app/core/models/scripts/bedrock_models.yaml
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
provider: bedrock
|
||||||
|
enabled: true
|
||||||
|
models:
|
||||||
|
- name: ai21
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon nova
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: bedrock
|
||||||
|
- name: anthropic claude
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- document
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: deepseek
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: meta
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: mistral
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: openai
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: qwen
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.rerank-v1:0
|
||||||
|
type: rerank
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.rerank-v1:0重排序模型,5120上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere.rerank-v3-5:0
|
||||||
|
type: rerank
|
||||||
|
provider: bedrock
|
||||||
|
description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.nova-2-multimodal-embeddings-v1:0
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
- vision
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.titan-embed-text-v1
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.titan-embed-text-v2:0
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere.embed-english-v3
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: Cohere Embed 3 English文本嵌入模型,512上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere.embed-multilingual-v3
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
820
api/app/core/models/scripts/dashscope_models.yaml
Normal file
820
api/app/core/models/scripts/dashscope_models.yaml
Normal file
@@ -0,0 +1,820 @@
|
|||||||
|
provider: dashscope
|
||||||
|
enabled: true
|
||||||
|
models:
|
||||||
|
- name: deepseek-r1-distill-qwen-14b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-r1-distill-qwen-32b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-r1
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3.1
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3.2-exp
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3.2
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: farui-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: glm-4.7
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qvq-max-latest
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qvq-max
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-coder-turbo-0919
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-max-latest
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-max-longcontext
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-max
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-mt-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 翻译模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-mt-turbo
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 翻译模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0112
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0125
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0723
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0806
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0919
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-1125
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-1127
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-1220
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-max
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-0809
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-2025-01-02
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-2025-01-25
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-latest
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen2.5-0.5b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-14b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-235b-a22b-instruct-2507
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-235b-a22b-thinking-2507
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-235b-a22b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-30b-a3b-instruct-2507
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-30b-a3b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-32b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-4b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-8b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-30b-a3b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-480b-a35b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-plus-2025-09-23
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max-2025-09-23
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- 联网搜索
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max-2026-01-23
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- 联网搜索
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max-preview
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- 联网搜索
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-next-80b-a3b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-next-80b-a3b-thinking
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-omni-flash-2025-12-01
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
- audio
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-235b-a22b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-235b-a22b-thinking
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-30b-a3b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-30b-a3b-thinking
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-flash
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-plus-2025-09-23
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwq-32b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwq-plus-0305
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwq-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: gte-rerank-v2
|
||||||
|
type: rerank
|
||||||
|
provider: dashscope
|
||||||
|
description: gte-rerank-v2重排序模型,4000上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: dashscope
|
||||||
|
- name: gte-rerank
|
||||||
|
type: rerank
|
||||||
|
provider: dashscope
|
||||||
|
description: gte-rerank重排序模型,4000上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: dashscope
|
||||||
|
- name: multimodal-embedding-v1
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v1
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v2
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v3
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v4
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
143
api/app/core/models/scripts/loader.py
Normal file
143
api/app/core/models/scripts/loader.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.models.models_model import ModelBase, ModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
||||||
|
"""从YAML文件加载指定供应商的模型配置"""
|
||||||
|
config_dir = Path(__file__).parent
|
||||||
|
config_file = config_dir / f"{provider.value}_models.yaml"
|
||||||
|
|
||||||
|
if not config_file.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
with open(config_file, 'r', encoding='utf-8') as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# 检查是否需要加载(默认为 true)
|
||||||
|
if not data.get('enabled', True):
|
||||||
|
return []
|
||||||
|
|
||||||
|
return data.get('models', [])
|
||||||
|
|
||||||
|
|
||||||
|
def _disable_yaml_config(provider: ModelProvider) -> None:
|
||||||
|
"""将YAML文件的enabled标志设置为false"""
|
||||||
|
config_dir = Path(__file__).parent
|
||||||
|
config_file = config_dir / f"{provider.value}_models.yaml"
|
||||||
|
|
||||||
|
if not config_file.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(config_file, 'r', encoding='utf-8') as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
data['enabled'] = False
|
||||||
|
|
||||||
|
with open(config_file, 'w', encoding='utf-8') as f:
|
||||||
|
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
加载模型配置到数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
providers: 要加载的供应商列表,None表示加载所有
|
||||||
|
silent: 是否静默模式(不输出详细日志)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 加载结果统计 {"success": int, "skipped": int, "failed": int}
|
||||||
|
"""
|
||||||
|
result = {"success": 0, "skipped": 0, "failed": 0}
|
||||||
|
|
||||||
|
# 确定要加载的供应商
|
||||||
|
if providers:
|
||||||
|
target_providers = [ModelProvider(p) if isinstance(p, str) else p for p in providers]
|
||||||
|
else:
|
||||||
|
target_providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||||
|
|
||||||
|
for provider in target_providers:
|
||||||
|
# 从YAML文件加载模型配置
|
||||||
|
models = _load_yaml_config(provider)
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
if not silent:
|
||||||
|
print(f"警告: 供应商 '{provider.value}' 暂无预定义模型")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not silent:
|
||||||
|
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||||
|
|
||||||
|
# provider_success = 0
|
||||||
|
for model_data in models:
|
||||||
|
try:
|
||||||
|
# 检查模型是否已存在
|
||||||
|
existing = db.query(ModelBase).filter(
|
||||||
|
ModelBase.name == model_data["name"],
|
||||||
|
ModelBase.provider == model_data["provider"]
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# 更新现有模型配置
|
||||||
|
for key, value in model_data.items():
|
||||||
|
setattr(existing, key, value)
|
||||||
|
db.commit()
|
||||||
|
if not silent:
|
||||||
|
print(f"更新成功: {model_data['name']}")
|
||||||
|
result["success"] += 1
|
||||||
|
# provider_success += 1
|
||||||
|
else:
|
||||||
|
# 创建新模型
|
||||||
|
model = ModelBase(**model_data)
|
||||||
|
db.add(model)
|
||||||
|
db.commit()
|
||||||
|
if not silent:
|
||||||
|
print(f"添加成功: {model_data['name']}")
|
||||||
|
result["success"] += 1
|
||||||
|
# provider_success += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
if not silent:
|
||||||
|
print(f"添加失败: {model_data['name']} - {str(e)}")
|
||||||
|
result["failed"] += 1
|
||||||
|
|
||||||
|
# 如果该供应商的模型全部加载成功,将enabled设置为false
|
||||||
|
# if provider_success == len(models):
|
||||||
|
_disable_yaml_config(provider)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_models_by_provider(db: Session, provider: str) -> dict:
|
||||||
|
"""
|
||||||
|
加载指定供应商的模型配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
provider: 供应商名称(字符串或ModelProvider枚举)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 加载结果统计
|
||||||
|
"""
|
||||||
|
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
|
||||||
|
return load_models(db, providers=[provider_enum])
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_providers() -> list[Callable[[], str]]:
|
||||||
|
"""获取所有可用的供应商列表(从ModelProvider枚举获取,排除COMPOSITE)"""
|
||||||
|
return [p.value for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||||
|
|
||||||
|
|
||||||
|
def get_models_by_provider(provider: str) -> list[dict]:
|
||||||
|
"""获取指定供应商的模型配置列表"""
|
||||||
|
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
|
||||||
|
return _load_yaml_config(provider_enum)
|
||||||
294
api/app/core/models/scripts/openai_models.yaml
Normal file
294
api/app/core/models/scripts/openai_models.yaml
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
provider: openai
|
||||||
|
enabled: true
|
||||||
|
models:
|
||||||
|
- name: chatgpt-4o-latest
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-0125
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-1106
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-16k
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-instruct
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-0125-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-1106-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-turbo-2024-04-09
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-turbo-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-turbo
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: openai
|
||||||
|
- name: o1-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: openai
|
||||||
|
- name: o1
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-2025-04-16
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-mini-2025-01-31
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-mini
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-pro-2025-06-10
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-pro
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o4-mini-2025-04-16
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o4-mini
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: text-embedding-3-large
|
||||||
|
type: embedding
|
||||||
|
provider: openai
|
||||||
|
description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本向量模型
|
||||||
|
logo: openai
|
||||||
|
- name: text-embedding-3-small
|
||||||
|
type: embedding
|
||||||
|
provider: openai
|
||||||
|
description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本向量模型
|
||||||
|
logo: openai
|
||||||
|
- name: text-embedding-ada-002
|
||||||
|
type: embedding
|
||||||
|
provider: openai
|
||||||
|
description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本向量模型
|
||||||
|
logo: openai
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
import copy
|
|
||||||
import re
|
|
||||||
from io import BytesIO
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from app.core.rag.nlp import tokenize, is_english
|
|
||||||
from app.core.rag.nlp import rag_tokenizer
|
|
||||||
from app.core.rag.deepdoc.parser import PdfParser, PlainParser
|
|
||||||
from app.core.rag.deepdoc.parser.ppt_parser import RAGPptParser as PptParser
|
|
||||||
from PyPDF2 import PdfReader as pdf2_read
|
|
||||||
from app.core.rag.app.naive import by_plaintext, PARSERS
|
|
||||||
|
|
||||||
class Ppt(PptParser):
|
|
||||||
def __call__(self, fnm, from_page, to_page, callback=None):
|
|
||||||
txts = super().__call__(fnm, from_page, to_page)
|
|
||||||
|
|
||||||
callback(0.5, "Text extraction finished.")
|
|
||||||
import aspose.slides as slides
|
|
||||||
import aspose.pydrawing as drawing
|
|
||||||
imgs = []
|
|
||||||
with slides.Presentation(BytesIO(fnm)) as presentation:
|
|
||||||
for i, slide in enumerate(presentation.slides[from_page: to_page]):
|
|
||||||
try:
|
|
||||||
with BytesIO() as buffered:
|
|
||||||
slide.get_thumbnail(
|
|
||||||
0.1, 0.1).save(
|
|
||||||
buffered, drawing.imaging.ImageFormat.jpeg)
|
|
||||||
buffered.seek(0)
|
|
||||||
imgs.append(Image.open(buffered).copy())
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e
|
|
||||||
assert len(imgs) == len(
|
|
||||||
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
|
|
||||||
callback(0.9, "Image extraction finished")
|
|
||||||
self.is_english = is_english(txts)
|
|
||||||
return [(txts[i], imgs[i]) for i in range(len(txts))]
|
|
||||||
|
|
||||||
class Pdf(PdfParser):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def __garbage(self, txt):
|
|
||||||
txt = txt.lower().strip()
|
|
||||||
if re.match(r"[0-9\.,%/-]+$", txt):
|
|
||||||
return True
|
|
||||||
if len(txt) < 3:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __call__(self, filename, binary=None, from_page=0,
|
|
||||||
to_page=100000, zoomin=3, callback=None):
|
|
||||||
from timeit import default_timer as timer
|
|
||||||
start = timer()
|
|
||||||
callback(msg="OCR started")
|
|
||||||
self.__images__(filename if not binary else binary,
|
|
||||||
zoomin, from_page, to_page, callback)
|
|
||||||
callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start))
|
|
||||||
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
|
|
||||||
len(self.boxes), len(self.page_images))
|
|
||||||
res = []
|
|
||||||
for i in range(len(self.boxes)):
|
|
||||||
lines = "\n".join([b["text"] for b in self.boxes[i]
|
|
||||||
if not self.__garbage(b["text"])])
|
|
||||||
res.append((lines, self.page_images[i]))
|
|
||||||
callback(0.9, "Page {}~{}: Parsing finished".format(
|
|
||||||
from_page, min(to_page, self.total_page)))
|
|
||||||
return res, []
|
|
||||||
|
|
||||||
|
|
||||||
class PlainPdf(PlainParser):
|
|
||||||
def __call__(self, filename, binary=None, from_page=0,
|
|
||||||
to_page=100000, callback=None, **kwargs):
|
|
||||||
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
|
|
||||||
page_txt = []
|
|
||||||
for page in self.pdf.pages[from_page: to_page]:
|
|
||||||
page_txt.append(page.extract_text())
|
|
||||||
callback(0.9, "Parsing finished")
|
|
||||||
return [(txt, None) for txt in page_txt], []
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|
||||||
lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs):
|
|
||||||
"""
|
|
||||||
The supported file formats are pdf, pptx.
|
|
||||||
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
|
||||||
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
|
|
||||||
"""
|
|
||||||
if parser_config is None:
|
|
||||||
parser_config = {}
|
|
||||||
eng = lang.lower() == "english"
|
|
||||||
doc = {
|
|
||||||
"docnm_kwd": filename,
|
|
||||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
||||||
}
|
|
||||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
|
||||||
res = []
|
|
||||||
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
|
|
||||||
if not binary:
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
binary = f.read()
|
|
||||||
ppt_parser = Ppt()
|
|
||||||
for pn, (txt, img) in enumerate(ppt_parser(
|
|
||||||
filename if not binary else binary, from_page, 1000000, callback)):
|
|
||||||
d = copy.deepcopy(doc)
|
|
||||||
pn += from_page
|
|
||||||
d["image"] = img
|
|
||||||
d["doc_type_kwd"] = "image"
|
|
||||||
d["page_num_int"] = [pn + 1]
|
|
||||||
d["top_int"] = [0]
|
|
||||||
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
|
||||||
tokenize(d, txt, eng)
|
|
||||||
res.append(d)
|
|
||||||
return res
|
|
||||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
|
||||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
|
||||||
|
|
||||||
if isinstance(layout_recognizer, bool):
|
|
||||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
|
||||||
|
|
||||||
name = layout_recognizer.strip().lower()
|
|
||||||
parser = PARSERS.get(name, by_plaintext)
|
|
||||||
callback(0.1, "Start to parse.")
|
|
||||||
|
|
||||||
sections, _, _ = parser(
|
|
||||||
filename=filename,
|
|
||||||
binary=binary,
|
|
||||||
from_page=from_page,
|
|
||||||
to_page=to_page,
|
|
||||||
lang=lang,
|
|
||||||
callback=callback,
|
|
||||||
vision_model=vision_model,
|
|
||||||
pdf_cls=Pdf,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not sections:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if name in ["tcadp", "docling", "mineru"]:
|
|
||||||
parser_config["chunk_token_num"] = 0
|
|
||||||
|
|
||||||
callback(0.8, "Finish parsing.")
|
|
||||||
|
|
||||||
for pn, (txt, img) in enumerate(sections):
|
|
||||||
d = copy.deepcopy(doc)
|
|
||||||
pn += from_page
|
|
||||||
if img:
|
|
||||||
d["image"] = img
|
|
||||||
d["page_num_int"] = [pn + 1]
|
|
||||||
d["top_int"] = [0]
|
|
||||||
d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
|
|
||||||
tokenize(d, txt, eng)
|
|
||||||
res.append(d)
|
|
||||||
return res
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
"file type not supported yet(pptx, pdf supported)")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import sys
|
|
||||||
|
|
||||||
def dummy(a, b):
|
|
||||||
pass
|
|
||||||
chunk(sys.argv[1], callback=dummy)
|
|
||||||
@@ -4,7 +4,7 @@ from enum import StrEnum, auto
|
|||||||
class Field(StrEnum):
|
class Field(StrEnum):
|
||||||
CONTENT_KEY = "page_content"
|
CONTENT_KEY = "page_content"
|
||||||
METADATA_KEY = "metadata"
|
METADATA_KEY = "metadata"
|
||||||
GROUP_KEY = "group_id"
|
GROUP_KEY = "end_user_id"
|
||||||
VECTOR = auto()
|
VECTOR = auto()
|
||||||
# Sparse Vector aims to support full text search
|
# Sparse Vector aims to support full text search
|
||||||
SPARSE_VECTOR = auto()
|
SPARSE_VECTOR = auto()
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def generate_signed_url(
|
|||||||
"""
|
"""
|
||||||
if base_url is None:
|
if base_url is None:
|
||||||
# Use SERVER_IP or default to localhost
|
# Use SERVER_IP or default to localhost
|
||||||
server_url = f"http://{settings.SERVER_IP}:8000/api"
|
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||||
base_url = server_url
|
base_url = server_url
|
||||||
|
|
||||||
# Calculate expiration timestamp
|
# Calculate expiration timestamp
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class BaiduSearchTool(BuiltinTool):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"
|
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、视频搜索"
|
||||||
|
|
||||||
def get_required_config_parameters(self) -> List[str]:
|
def get_required_config_parameters(self) -> List[str]:
|
||||||
return ["api_key"]
|
return ["api_key"]
|
||||||
@@ -33,7 +33,7 @@ class BaiduSearchTool(BuiltinTool):
|
|||||||
ToolParameter(
|
ToolParameter(
|
||||||
name="search_type",
|
name="search_type",
|
||||||
type=ParameterType.STRING,
|
type=ParameterType.STRING,
|
||||||
description="搜索类型",
|
description="搜索类型, web: 网页搜索;news:新闻搜索;image:图片搜索;video视频搜索",
|
||||||
required=False,
|
required=False,
|
||||||
default="web",
|
default="web",
|
||||||
enum=["web", "news", "image", "video"]
|
enum=["web", "news", "image", "video"]
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ logger = get_config_logger()
|
|||||||
|
|
||||||
|
|
||||||
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
|
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
|
||||||
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
|
config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
|
||||||
"""Parse model ID from string or UUID."""
|
"""Parse model ID from string or UUID."""
|
||||||
if model_id is None:
|
if model_id is None:
|
||||||
return None
|
return None
|
||||||
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
|
|||||||
model_type: str,
|
model_type: str,
|
||||||
db: Session,
|
db: Session,
|
||||||
tenant_id: Optional[UUID] = None,
|
tenant_id: Optional[UUID] = None,
|
||||||
config_id: Optional[int] = None,
|
config_id: Optional[UUID] = None,
|
||||||
workspace_id: Optional[UUID] = None
|
workspace_id: Optional[UUID] = None
|
||||||
) -> tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Validate that a model exists and is active.
|
"""Validate that a model exists and is active.
|
||||||
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
|
|||||||
db: Session,
|
db: Session,
|
||||||
tenant_id: Optional[UUID] = None,
|
tenant_id: Optional[UUID] = None,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
config_id: Optional[int] = None,
|
config_id: Optional[UUID] = None,
|
||||||
workspace_id: Optional[UUID] = None
|
workspace_id: Optional[UUID] = None
|
||||||
) -> tuple[Optional[UUID], Optional[str]]:
|
) -> tuple[Optional[UUID], Optional[str]]:
|
||||||
"""Validate and resolve a model ID, checking existence and active status.
|
"""Validate and resolve a model ID, checking existence and active status.
|
||||||
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
|
|||||||
|
|
||||||
|
|
||||||
def validate_embedding_model(
|
def validate_embedding_model(
|
||||||
config_id: int,
|
config_id: UUID,
|
||||||
embedding_id: Union[str, UUID, None],
|
embedding_id: Union[str, UUID, None],
|
||||||
db: Session,
|
db: Session,
|
||||||
tenant_id: Optional[UUID] = None,
|
tenant_id: Optional[UUID] = None,
|
||||||
@@ -256,7 +256,7 @@ def validate_embedding_model(
|
|||||||
|
|
||||||
|
|
||||||
def validate_llm_model(
|
def validate_llm_model(
|
||||||
config_id: int,
|
config_id: UUID,
|
||||||
llm_id: Union[str, UUID, None],
|
llm_id: Union[str, UUID, None],
|
||||||
db: Session,
|
db: Session,
|
||||||
tenant_id: Optional[UUID] = None,
|
tenant_id: Optional[UUID] = None,
|
||||||
|
|||||||
@@ -11,17 +11,12 @@ from typing import Any
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.graph_builder import GraphBuilder
|
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||||
|
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_config import VariableType
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
# from app.core.tools.registry import ToolRegistry
|
|
||||||
# from app.core.tools.executor import ToolExecutor
|
|
||||||
# from app.core.tools.langchain_adapter import LangchainAdapter
|
|
||||||
# TOOL_MANAGEMENT_AVAILABLE = True
|
|
||||||
# from app.db import get_db
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -55,6 +50,8 @@ class WorkflowExecutor:
|
|||||||
self.execution_config = workflow_config.get("execution_config", {})
|
self.execution_config = workflow_config.get("execution_config", {})
|
||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id = None
|
||||||
|
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||||
|
self.activate_end: str | None = None
|
||||||
|
|
||||||
self.checkpoint_config = RunnableConfig(
|
self.checkpoint_config = RunnableConfig(
|
||||||
configurable={
|
configurable={
|
||||||
@@ -127,7 +124,6 @@ class WorkflowExecutor:
|
|||||||
"user_id": self.user_id,
|
"user_id": self.user_id,
|
||||||
"error": None,
|
"error": None,
|
||||||
"error_node": None,
|
"error_node": None,
|
||||||
"streaming_buffer": {}, # 流式缓冲区
|
|
||||||
"cycle_nodes": [
|
"cycle_nodes": [
|
||||||
node.get("id")
|
node.get("id")
|
||||||
for node in self.workflow_config.get("nodes")
|
for node in self.workflow_config.get("nodes")
|
||||||
@@ -139,9 +135,8 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _build_final_output(self, result, elapsed_time):
|
def _build_final_output(self, result, elapsed_time, final_output):
|
||||||
node_outputs = result.get("node_outputs", {})
|
node_outputs = result.get("node_outputs", {})
|
||||||
final_output = self._extract_final_output(node_outputs)
|
|
||||||
token_usage = self._aggregate_token_usage(node_outputs)
|
token_usage = self._aggregate_token_usage(node_outputs)
|
||||||
conversation_id = None
|
conversation_id = None
|
||||||
for node_id, node_output in node_outputs.items():
|
for node_id, node_output in node_outputs.items():
|
||||||
@@ -161,6 +156,146 @@ class WorkflowExecutor:
|
|||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _update_scope_activate(self, scope, status=None):
|
||||||
|
"""
|
||||||
|
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||||
|
|
||||||
|
Iterates over all End nodes in `self.end_outputs` and calls
|
||||||
|
`update_activate` on each, which may:
|
||||||
|
- Activate variable segments that depend on the completed node/scope.
|
||||||
|
- Activate the entire End node output if all control conditions are met.
|
||||||
|
|
||||||
|
If any End node becomes active and `self.activate_end` is not yet set,
|
||||||
|
this node will be marked as the currently active End node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope (str): The node ID or scope that has completed execution.
|
||||||
|
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||||
|
"""
|
||||||
|
for node in self.end_outputs.keys():
|
||||||
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
|
if self.end_outputs[node].activate and self.activate_end is None:
|
||||||
|
self.activate_end = node
|
||||||
|
|
||||||
|
def _update_stream_output_status(self, activate, data):
|
||||||
|
"""
|
||||||
|
Update the stream output state of End nodes based on workflow state updates.
|
||||||
|
|
||||||
|
This method checks which nodes/scopes are activated and propagates
|
||||||
|
activation to End nodes accordingly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||||
|
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
For each node in `data`:
|
||||||
|
1. If the node is activated (`activate[node_id]` is True),
|
||||||
|
retrieve its output status from `runtime_vars`.
|
||||||
|
2. Call `_update_scope_activate` to propagate the activation
|
||||||
|
to all relevant End nodes and update `self.activate_end`.
|
||||||
|
"""
|
||||||
|
for node_id in data.keys():
|
||||||
|
if activate.get(node_id):
|
||||||
|
node_output_status = (
|
||||||
|
data[node_id]
|
||||||
|
.get('runtime_vars', {})
|
||||||
|
.get(node_id)
|
||||||
|
.get("output")
|
||||||
|
)
|
||||||
|
self._update_scope_activate(node_id, status=node_output_status)
|
||||||
|
|
||||||
|
async def _emit_active_chunks(
|
||||||
|
self,
|
||||||
|
node_outputs: dict,
|
||||||
|
variables: dict,
|
||||||
|
force=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Process and yield all currently active output segments for the currently active End node.
|
||||||
|
|
||||||
|
This method handles stream-mode output for an End node by iterating through its output segments
|
||||||
|
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||||
|
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||||
|
2. For each segment:
|
||||||
|
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||||
|
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||||
|
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||||
|
then transform the result with `_trans_output_string`.
|
||||||
|
3. Yield a stream event of type "message" containing the processed chunk.
|
||||||
|
4. Move the `cursor` forward after processing each segment.
|
||||||
|
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||||
|
and reset `activate_end` to None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_outputs (dict): Current runtime node outputs, used for variable evaluation.
|
||||||
|
variables (dict): Current runtime variables, used for variable evaluation.
|
||||||
|
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
dict: A stream event of type "message" containing the processed chunk.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||||
|
- This method only processes the currently active End node (`self.activate_end`).
|
||||||
|
- Use `force=True` for final emission regardless of activation state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
end_info = self.end_outputs[self.activate_end]
|
||||||
|
|
||||||
|
while end_info.cursor < len(end_info.outputs):
|
||||||
|
final_chunk = ''
|
||||||
|
current_segment = end_info.outputs[end_info.cursor]
|
||||||
|
|
||||||
|
if not current_segment.activate and not force:
|
||||||
|
# Stop processing until this segment becomes active
|
||||||
|
break
|
||||||
|
|
||||||
|
# Literal segment
|
||||||
|
if not current_segment.is_variable:
|
||||||
|
final_chunk += current_segment.literal
|
||||||
|
else:
|
||||||
|
# Variable segment: evaluate and transform
|
||||||
|
try:
|
||||||
|
chunk = evaluate_expression(
|
||||||
|
current_segment.literal,
|
||||||
|
variables=variables,
|
||||||
|
node_outputs=node_outputs
|
||||||
|
)
|
||||||
|
chunk = self._trans_output_string(chunk)
|
||||||
|
final_chunk += chunk
|
||||||
|
except ValueError:
|
||||||
|
# Log failed evaluation but continue streaming
|
||||||
|
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
||||||
|
|
||||||
|
if final_chunk:
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": final_chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Advance cursor after processing
|
||||||
|
end_info.cursor += 1
|
||||||
|
|
||||||
|
# Remove End node from active tracking if all segments have been processed
|
||||||
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _trans_output_string(content):
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
return "\n".join(content)
|
||||||
|
else:
|
||||||
|
return str(content)
|
||||||
|
|
||||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||||
"""构建 LangGraph
|
"""构建 LangGraph
|
||||||
|
|
||||||
@@ -173,6 +308,7 @@ class WorkflowExecutor:
|
|||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
self.start_node_id = builder.start_node_id
|
self.start_node_id = builder.start_node_id
|
||||||
|
self.end_outputs = builder.end_node_map
|
||||||
graph = builder.build()
|
graph = builder.build()
|
||||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||||
|
|
||||||
@@ -205,14 +341,28 @@ class WorkflowExecutor:
|
|||||||
try:
|
try:
|
||||||
|
|
||||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||||
|
full_content = ''
|
||||||
|
for end_id in self.end_outputs.keys():
|
||||||
|
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
|
||||||
|
result["messages"].extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": input_data.get("message", '')
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
return self._build_final_output(result, elapsed_time)
|
return self._build_final_output(result, elapsed_time, full_content)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 计算耗时(即使失败也记录)
|
# 计算耗时(即使失败也记录)
|
||||||
@@ -261,7 +411,7 @@ class WorkflowExecutor:
|
|||||||
"data": {
|
"data": {
|
||||||
"execution_id": self.execution_id,
|
"execution_id": self.execution_id,
|
||||||
"workspace_id": self.workspace_id,
|
"workspace_id": self.workspace_id,
|
||||||
"timestamp": start_time.isoformat()
|
"timestamp": int(start_time.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,7 +423,8 @@ class WorkflowExecutor:
|
|||||||
# 3. Execute workflow
|
# 3. Execute workflow
|
||||||
try:
|
try:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
full_content = ''
|
||||||
|
self._update_scope_activate("sys")
|
||||||
async for event in graph.astream(
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||||
@@ -293,20 +444,42 @@ class WorkflowExecutor:
|
|||||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
if event_type == "node_chunk":
|
||||||
f"- execution_id: {self.execution_id}")
|
node_id = data.get("node_id")
|
||||||
yield {
|
if self.activate_end:
|
||||||
"event": event_type, # "message" or "node_chunk"
|
end_info = self.end_outputs.get(self.activate_end)
|
||||||
"data": {
|
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||||
"node_id": data.get("node_id"),
|
continue
|
||||||
"chunk": data.get("chunk"),
|
current_output = end_info.outputs[end_info.cursor]
|
||||||
"full_content": data.get("full_content"),
|
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||||
"chunk_index": data.get("chunk_index"),
|
if data.get("done"):
|
||||||
"is_prefix": data.get("is_prefix"),
|
end_info.cursor += 1
|
||||||
"is_suffix": data.get("is_suffix"),
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
"conversation_id": input_data.get("conversation_id"),
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
|
else:
|
||||||
|
full_content += data.get("chunk")
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": data.get("chunk")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
||||||
|
f"- execution_id: {self.execution_id}")
|
||||||
|
|
||||||
|
elif event_type == "node_error":
|
||||||
|
yield {
|
||||||
|
"event": event_type, # "message" or "node_chunk"
|
||||||
|
"data": {
|
||||||
|
"node_id": data.get("node_id"),
|
||||||
|
"status": "failed",
|
||||||
|
"input": data.get("input_data"),
|
||||||
|
"elapsed_time": data.get("elapsed_time"),
|
||||||
|
"output": None,
|
||||||
|
"error": data.get("error")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
elif mode == "debug":
|
elif mode == "debug":
|
||||||
# Handle debug information (node execution status)
|
# Handle debug information (node execution status)
|
||||||
@@ -325,14 +498,15 @@ class WorkflowExecutor:
|
|||||||
conversation_id = input_data.get("conversation_id")
|
conversation_id = input_data.get("conversation_id")
|
||||||
logger.info(f"[NODE-START] Node starts execution: {node_name} "
|
logger.info(f"[NODE-START] Node starts execution: {node_name} "
|
||||||
f"- execution_id: {self.execution_id}")
|
f"- execution_id: {self.execution_id}")
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "node_start",
|
"event": "node_start",
|
||||||
"data": {
|
"data": {
|
||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"execution_id": self.execution_id,
|
"execution_id": self.execution_id,
|
||||||
"timestamp": data.get("timestamp"),
|
"timestamp": int(datetime.datetime.fromisoformat(
|
||||||
|
data.get("timestamp")
|
||||||
|
).timestamp() * 1000),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
elif event_type == "task_result":
|
elif event_type == "task_result":
|
||||||
@@ -351,21 +525,82 @@ class WorkflowExecutor:
|
|||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"execution_id": self.execution_id,
|
"execution_id": self.execution_id,
|
||||||
"timestamp": data.get("timestamp"),
|
"timestamp": int(datetime.datetime.fromisoformat(
|
||||||
"state": result.get("node_outputs", {}).get(node_name),
|
data.get("timestamp")
|
||||||
|
).timestamp() * 1000),
|
||||||
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||||
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||||
|
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# Handle state updates - store final state
|
# Handle state updates - store final state
|
||||||
# TODO:流式输出点
|
state = graph.get_state(config=self.checkpoint_config).values
|
||||||
|
node_outputs = state.get("runtime_vars", {})
|
||||||
|
variables = state.get("variables", {})
|
||||||
|
activate = state.get("activate", {})
|
||||||
|
for _, node_data in data.items():
|
||||||
|
node_outputs |= node_data.get("runtime_vars", {})
|
||||||
|
variables |= node_data.get("variables", {})
|
||||||
|
|
||||||
|
self._update_stream_output_status(activate, data)
|
||||||
|
wait = False
|
||||||
|
while self.activate_end and not wait:
|
||||||
|
async for msg_event in self._emit_active_chunks(
|
||||||
|
node_outputs=node_outputs,
|
||||||
|
variables=variables
|
||||||
|
):
|
||||||
|
full_content += msg_event["data"]['chunk']
|
||||||
|
yield msg_event
|
||||||
|
|
||||||
|
if self.activate_end:
|
||||||
|
wait = True
|
||||||
|
else:
|
||||||
|
self._update_stream_output_status(activate, data)
|
||||||
|
|
||||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||||
f"- execution_id: {self.execution_id}")
|
f"- execution_id: {self.execution_id}")
|
||||||
|
|
||||||
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
|
node_outputs = result.get("runtime_vars", {})
|
||||||
|
variables = result.get("variables", {})
|
||||||
|
self.end_outputs = {
|
||||||
|
node_id: node_info
|
||||||
|
for node_id, node_info in self.end_outputs.items()
|
||||||
|
if node_info.activate
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.end_outputs or self.activate_end:
|
||||||
|
while self.activate_end:
|
||||||
|
async for msg_event in self._emit_active_chunks(
|
||||||
|
node_outputs=node_outputs,
|
||||||
|
variables=variables,
|
||||||
|
force=True
|
||||||
|
):
|
||||||
|
full_content += msg_event["data"]['chunk']
|
||||||
|
yield msg_event
|
||||||
|
|
||||||
|
if not self.activate_end and self.end_outputs:
|
||||||
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
result = graph.get_state(self.checkpoint_config).values
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
|
logger.info(result)
|
||||||
|
result["messages"].extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": input_data.get("message", '')
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Workflow execution completed (streaming), "
|
f"Workflow execution completed (streaming), "
|
||||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||||
@@ -374,7 +609,7 @@ class WorkflowExecutor:
|
|||||||
# 发送 workflow_end 事件
|
# 发送 workflow_end 事件
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self._build_final_output(result, elapsed_time)
|
"data": self._build_final_output(result, elapsed_time, full_content)
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -396,31 +631,6 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
|
|
||||||
"""从节点输出中提取最终输出
|
|
||||||
|
|
||||||
优先级:
|
|
||||||
1. 最后一个执行的非 start/end 节点的 output
|
|
||||||
2. 如果没有节点输出,返回 None
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_outputs: 所有节点的输出
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
最终输出字符串或 None
|
|
||||||
"""
|
|
||||||
if not node_outputs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 获取最后一个节点的输出
|
|
||||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
|
||||||
|
|
||||||
if last_node_output and isinstance(last_node_output, dict):
|
|
||||||
return last_node_output.get("output")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||||
"""聚合所有节点的 token 使用情况
|
"""聚合所有节点的 token 使用情况
|
||||||
@@ -511,178 +721,3 @@ async def execute_workflow_stream(
|
|||||||
)
|
)
|
||||||
async for event in executor.execute_stream(input_data):
|
async for event in executor.execute_stream(input_data):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
# ==================== 工具管理系统集成 ====================
|
|
||||||
|
|
||||||
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
|
||||||
# """获取工作流可用的工具列表
|
|
||||||
#
|
|
||||||
# Args:
|
|
||||||
# workspace_id: 工作空间ID
|
|
||||||
# user_id: 用户ID
|
|
||||||
#
|
|
||||||
# Returns:
|
|
||||||
# 可用工具列表
|
|
||||||
# """
|
|
||||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
|
||||||
# logger.warning("工具管理系统不可用")
|
|
||||||
# return []
|
|
||||||
#
|
|
||||||
# try:
|
|
||||||
# db = next(get_db())
|
|
||||||
#
|
|
||||||
# # 创建工具注册表
|
|
||||||
# registry = ToolRegistry(db)
|
|
||||||
#
|
|
||||||
# # 注册内置工具类
|
|
||||||
# from app.core.tools.builtin import (
|
|
||||||
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
|
||||||
# )
|
|
||||||
# registry.register_tool_class(DateTimeTool)
|
|
||||||
# registry.register_tool_class(JsonTool)
|
|
||||||
# registry.register_tool_class(BaiduSearchTool)
|
|
||||||
# registry.register_tool_class(MinerUTool)
|
|
||||||
# registry.register_tool_class(TextInTool)
|
|
||||||
#
|
|
||||||
# # 获取活跃的工具
|
|
||||||
# import uuid
|
|
||||||
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
|
||||||
# active_tools = [tool for tool in tools if tool.status.value == "active"]
|
|
||||||
#
|
|
||||||
# # 转换为Langchain工具
|
|
||||||
# langchain_tools = []
|
|
||||||
# for tool_info in active_tools:
|
|
||||||
# try:
|
|
||||||
# tool_instance = registry.get_tool(tool_info.id)
|
|
||||||
# if tool_instance:
|
|
||||||
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
|
||||||
# langchain_tools.append(langchain_tool)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
|
||||||
#
|
|
||||||
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
|
||||||
# return langchain_tools
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"获取工作流工具失败: {e}")
|
|
||||||
# return []
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class ToolWorkflowNode:
|
|
||||||
# """工具工作流节点 - 在工作流中执行工具"""
|
|
||||||
#
|
|
||||||
# def __init__(self, node_config: dict, workflow_config: dict):
|
|
||||||
# """初始化工具节点
|
|
||||||
#
|
|
||||||
# Args:
|
|
||||||
# node_config: 节点配置
|
|
||||||
# workflow_config: 工作流配置
|
|
||||||
# """
|
|
||||||
# self.node_config = node_config
|
|
||||||
# self.workflow_config = workflow_config
|
|
||||||
# self.tool_id = node_config.get("tool_id")
|
|
||||||
# self.tool_parameters = node_config.get("parameters", {})
|
|
||||||
#
|
|
||||||
# async def run(self, state: WorkflowState) -> WorkflowState:
|
|
||||||
# """执行工具节点"""
|
|
||||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
|
||||||
# logger.error("工具管理系统不可用")
|
|
||||||
# state["error"] = "工具管理系统不可用"
|
|
||||||
# return state
|
|
||||||
#
|
|
||||||
# try:
|
|
||||||
# from sqlalchemy.orm import Session
|
|
||||||
# db = next(get_db())
|
|
||||||
#
|
|
||||||
# # 创建工具执行器
|
|
||||||
# registry = ToolRegistry(db)
|
|
||||||
# executor = ToolExecutor(db, registry)
|
|
||||||
#
|
|
||||||
# # 准备参数(支持变量替换)
|
|
||||||
# parameters = self._prepare_parameters(state)
|
|
||||||
#
|
|
||||||
# # 执行工具
|
|
||||||
# result = await executor.execute_tool(
|
|
||||||
# tool_id=self.tool_id,
|
|
||||||
# parameters=parameters,
|
|
||||||
# user_id=uuid.UUID(state["user_id"]),
|
|
||||||
# workspace_id=uuid.UUID(state["workspace_id"])
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# # 更新状态
|
|
||||||
# node_id = self.node_config.get("id")
|
|
||||||
# if result.success:
|
|
||||||
# state["node_outputs"][node_id] = {
|
|
||||||
# "type": "tool",
|
|
||||||
# "tool_id": self.tool_id,
|
|
||||||
# "output": result.data,
|
|
||||||
# "execution_time": result.execution_time,
|
|
||||||
# "token_usage": result.token_usage
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# # 更新运行时变量
|
|
||||||
# if isinstance(result.data, dict):
|
|
||||||
# for key, value in result.data.items():
|
|
||||||
# state["runtime_vars"][f"{node_id}.{key}"] = value
|
|
||||||
# else:
|
|
||||||
# state["runtime_vars"][f"{node_id}.result"] = result.data
|
|
||||||
# else:
|
|
||||||
# state["error"] = result.error
|
|
||||||
# state["error_node"] = node_id
|
|
||||||
# state["node_outputs"][node_id] = {
|
|
||||||
# "type": "tool",
|
|
||||||
# "tool_id": self.tool_id,
|
|
||||||
# "error": result.error,
|
|
||||||
# "execution_time": result.execution_time
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# return state
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"工具节点执行失败: {e}")
|
|
||||||
# state["error"] = str(e)
|
|
||||||
# state["error_node"] = self.node_config.get("id")
|
|
||||||
# return state
|
|
||||||
#
|
|
||||||
# def _prepare_parameters(self, state: WorkflowState) -> dict:
|
|
||||||
# """准备工具参数(支持变量替换)"""
|
|
||||||
# parameters = {}
|
|
||||||
#
|
|
||||||
# for key, value in self.tool_parameters.items():
|
|
||||||
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
|
||||||
# # 变量替换
|
|
||||||
# var_path = value[2:-1]
|
|
||||||
#
|
|
||||||
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
|
||||||
# if "." in var_path:
|
|
||||||
# parts = var_path.split(".")
|
|
||||||
# current = state.get("variables", {})
|
|
||||||
#
|
|
||||||
# for part in parts:
|
|
||||||
# if isinstance(current, dict) and part in current:
|
|
||||||
# current = current[part]
|
|
||||||
# else:
|
|
||||||
# # 尝试从运行时变量获取
|
|
||||||
# runtime_key = ".".join(parts)
|
|
||||||
# current = state.get("runtime_vars", {}).get(runtime_key, value)
|
|
||||||
# break
|
|
||||||
#
|
|
||||||
# parameters[key] = current
|
|
||||||
# else:
|
|
||||||
# # 简单变量
|
|
||||||
# variables = state.get("variables", {})
|
|
||||||
# parameters[key] = variables.get(var_path, value)
|
|
||||||
# else:
|
|
||||||
# parameters[key] = value
|
|
||||||
#
|
|
||||||
# return parameters
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# # 注册工具节点到NodeFactory(如果存在)
|
|
||||||
# try:
|
|
||||||
# from app.core.workflow.nodes import NodeFactory
|
|
||||||
# if hasattr(NodeFactory, 'register_node_type'):
|
|
||||||
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
|
|
||||||
# logger.info("工具节点已注册到工作流系统")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.warning(f"注册工具节点失败: {e}")
|
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||||
from langgraph.types import Send
|
from langgraph.types import Send
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
@@ -15,6 +18,149 @@ from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputContent(BaseModel):
|
||||||
|
"""
|
||||||
|
Represents a single output segment of an End node.
|
||||||
|
|
||||||
|
An output segment can be either:
|
||||||
|
- literal text (static string)
|
||||||
|
- a variable placeholder (e.g. {{ node.field }})
|
||||||
|
|
||||||
|
Each segment has its own activation state, which is especially
|
||||||
|
important in stream mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
literal: str = Field(
|
||||||
|
...,
|
||||||
|
description="Raw output content. Can be literal text or a variable placeholder."
|
||||||
|
)
|
||||||
|
|
||||||
|
activate: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether this output segment is currently active.\n"
|
||||||
|
"- True: allowed to be emitted/output\n"
|
||||||
|
"- False: blocked until activated by branch control"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
is_variable: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether this segment represents a variable placeholder.\n"
|
||||||
|
"True -> variable (e.g. {{ node.field }})\n"
|
||||||
|
"False -> literal text"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def depends_on_scope(self, scope: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this segment depends on a given scope.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if this segment references the given scope.
|
||||||
|
"""
|
||||||
|
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||||
|
return bool(re.search(pattern, self.literal))
|
||||||
|
|
||||||
|
|
||||||
|
class StreamOutputConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Streaming output configuration for an End node.
|
||||||
|
|
||||||
|
This configuration describes how the End node output behaves in streaming mode,
|
||||||
|
including:
|
||||||
|
- whether output emission is globally activated
|
||||||
|
- which upstream branch/control nodes gate the activation
|
||||||
|
- how each parsed output segment is streamed and activated
|
||||||
|
"""
|
||||||
|
|
||||||
|
activate: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Global activation flag for the End node output.\n"
|
||||||
|
"When False, output segments should not be emitted even if available.\n"
|
||||||
|
"This flag typically becomes True once required control branch conditions "
|
||||||
|
"are satisfied."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
control_nodes: dict[str, str] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Control branch conditions for this End node output.\n"
|
||||||
|
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||||
|
"The End node output becomes globally active when a controlling branch node "
|
||||||
|
"reports a matching completion status."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs: list[OutputContent] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Ordered list of output segments parsed from the output template.\n"
|
||||||
|
"Each segment represents either a literal text block or a variable placeholder "
|
||||||
|
"that may be activated independently."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor: int = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Streaming cursor index.\n"
|
||||||
|
"Indicates the next output segment index to be emitted.\n"
|
||||||
|
"Segments with index < cursor are considered already streamed."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_activate(self, scope: str, status=None):
|
||||||
|
"""
|
||||||
|
Update streaming activation state based on an upstream node or special variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope (str):
|
||||||
|
Identifier of the completed upstream entity.
|
||||||
|
- If a control branch node, it should match a key in `control_nodes`.
|
||||||
|
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||||
|
status (optional):
|
||||||
|
Completion status of the control branch node.
|
||||||
|
Required when `scope` refers to a control node.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
1. Control branch nodes:
|
||||||
|
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||||
|
branch label, the End node output becomes globally active (`activate = True`).
|
||||||
|
|
||||||
|
2. Variable output segments:
|
||||||
|
- For each segment that is a variable (`is_variable=True`):
|
||||||
|
- If the segment literal references `scope`, mark the segment as active.
|
||||||
|
- This applies both to regular node variables (e.g., "node_id.field")
|
||||||
|
and special system variables (e.g., "sys.xxx").
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- This method does not emit output or advance the streaming cursor.
|
||||||
|
- It only updates activation flags based on upstream events or special variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Case 1: resolve control branch dependency
|
||||||
|
if scope in self.control_nodes.keys():
|
||||||
|
if status is None:
|
||||||
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
|
if status == self.control_nodes[scope]:
|
||||||
|
self.activate = True
|
||||||
|
|
||||||
|
# Case 2: activate variable segments related to this node
|
||||||
|
for i in range(len(self.outputs)):
|
||||||
|
if (
|
||||||
|
self.outputs[i].is_variable
|
||||||
|
and self.outputs[i].depends_on_scope(scope)
|
||||||
|
):
|
||||||
|
self.outputs[i].activate = True
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
class GraphBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -29,10 +175,16 @@ class GraphBuilder:
|
|||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id = None
|
||||||
self.end_node_ids = []
|
self.end_node_ids = []
|
||||||
|
self.node_map = {node["id"]: node for node in self.nodes}
|
||||||
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
|
self._find_upstream_branch_node = lru_cache(
|
||||||
|
maxsize=len(self.nodes) * 2
|
||||||
|
)(self._find_upstream_branch_node)
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
|
self._analyze_end_node_output()
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -43,79 +195,207 @@ class GraphBuilder:
|
|||||||
def edges(self) -> list[dict[str, Any]]:
|
def edges(self) -> list[dict[str, Any]]:
|
||||||
return self.workflow_config.get("edges", [])
|
return self.workflow_config.get("edges", [])
|
||||||
|
|
||||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
def get_node_type(self, node_id: str) -> str:
|
||||||
"""
|
"""Retrieve the type of node given its ID.
|
||||||
Analyze the prefix configuration for End nodes.
|
|
||||||
|
|
||||||
This function scans each End node's output template, identifies
|
Args:
|
||||||
references to its direct upstream nodes, and extracts the prefix
|
node_id (str): The unique identifier of the node.
|
||||||
string appearing before the first reference.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple:
|
str: The type of the node.
|
||||||
- dict[str, str]: Mapping from upstream node ID to its End node prefix
|
|
||||||
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
|
Raises:
|
||||||
|
RuntimeError: If no node with the given `node_id` exists.
|
||||||
"""
|
"""
|
||||||
import re
|
try:
|
||||||
|
return self.node_map[node_id]["type"]
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||||
|
|
||||||
prefixes = {}
|
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
||||||
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
|
"""
|
||||||
|
Recursively find all upstream branch (control) nodes that influence the execution
|
||||||
|
of the given target node.
|
||||||
|
|
||||||
# 找到所有 End 节点
|
This method walks upstream along the workflow graph starting from `target_node`.
|
||||||
|
It distinguishes between:
|
||||||
|
- branch nodes (node types listed in `BRANCH_NODES`)
|
||||||
|
- non-branch nodes (ordinary processing nodes)
|
||||||
|
|
||||||
|
Traversal rules:
|
||||||
|
1. For each immediate upstream node:
|
||||||
|
- If it is a branch node, it is recorded as an affecting control node.
|
||||||
|
- If it is a non-branch node, the traversal continues recursively upstream.
|
||||||
|
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
||||||
|
a branch node, the traversal is considered invalid:
|
||||||
|
- `has_branch` will be False
|
||||||
|
- no branch nodes are returned.
|
||||||
|
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
||||||
|
branch node will `has_branch` be True.
|
||||||
|
|
||||||
|
Special case:
|
||||||
|
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
||||||
|
it is considered directly reachable from the workflow entry, and therefore
|
||||||
|
has no controlling branch nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_node (str):
|
||||||
|
The identifier of the node whose upstream control branches
|
||||||
|
are to be resolved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, tuple[tuple[str, str]]]:
|
||||||
|
- has_branch (bool):
|
||||||
|
True if every upstream path from `target_node` encounters
|
||||||
|
at least one branch node.
|
||||||
|
False if any path reaches a start node without a branch.
|
||||||
|
- branch_nodes (tuple[tuple[str, str]]):
|
||||||
|
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
||||||
|
representing all branch nodes that can influence `target_node`.
|
||||||
|
Returns an empty tuple if `has_branch` is False.
|
||||||
|
"""
|
||||||
|
source_nodes = [
|
||||||
|
{
|
||||||
|
"id": edge.get("source"),
|
||||||
|
"branch": edge.get("label")
|
||||||
|
}
|
||||||
|
for edge in self.edges
|
||||||
|
if edge.get("target") == target_node
|
||||||
|
]
|
||||||
|
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
|
return False, tuple()
|
||||||
|
|
||||||
|
branch_nodes = []
|
||||||
|
non_branch_nodes = []
|
||||||
|
|
||||||
|
for node_info in source_nodes:
|
||||||
|
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||||
|
branch_nodes.append(
|
||||||
|
(node_info["id"], node_info["branch"])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
|
has_branch = True
|
||||||
|
for node_id in non_branch_nodes:
|
||||||
|
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
||||||
|
has_branch = has_branch and node_has_branch
|
||||||
|
if not has_branch:
|
||||||
|
break
|
||||||
|
branch_nodes.extend(nodes)
|
||||||
|
if not has_branch:
|
||||||
|
branch_nodes = []
|
||||||
|
|
||||||
|
return has_branch, tuple(set(branch_nodes))
|
||||||
|
|
||||||
|
def _analyze_end_node_output(self):
|
||||||
|
"""
|
||||||
|
Analyze output templates of all End nodes and generate StreamOutputConfig.
|
||||||
|
|
||||||
|
This method is responsible for parsing the `output` field of End nodes,
|
||||||
|
splitting literal text and variable placeholders (e.g. {{ node.field }}),
|
||||||
|
and determining whether each output segment should be activated immediately
|
||||||
|
or controlled by upstream branch nodes.
|
||||||
|
|
||||||
|
In stream mode:
|
||||||
|
- If the End node is controlled by any upstream branch node, the output
|
||||||
|
will be initially inactive and controlled by those branch nodes.
|
||||||
|
- Otherwise, the output is activated immediately.
|
||||||
|
|
||||||
|
In non-stream mode:
|
||||||
|
- All outputs are activated by default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Collect all End nodes in the workflow
|
||||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
||||||
|
|
||||||
|
# Iterate through each End node to analyze its output
|
||||||
for end_node in end_nodes:
|
for end_node in end_nodes:
|
||||||
end_node_id = end_node.get("id")
|
end_node_id = end_node.get("id")
|
||||||
output_template = end_node.get("config", {}).get("output")
|
config = end_node.get("config", {})
|
||||||
|
output = config.get("output")
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
|
# Skip End nodes without output configuration
|
||||||
|
if not output:
|
||||||
if not output_template:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Find all node references in the template
|
# Regex to split output into:
|
||||||
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
|
# - variable placeholders: {{ ... }}
|
||||||
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
# - normal literal text
|
||||||
matches = list(re.finditer(pattern, output_template))
|
#
|
||||||
|
# Example:
|
||||||
|
# "Hello {{user.name}}!" ->
|
||||||
|
# ["Hello ", "{{user.name}}", "!"]
|
||||||
|
pattern = r'\{\{.*?\}\}|[^{}]+'
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
|
# Strict variable format: {{ node_id.field_name }}
|
||||||
|
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
variable_pattern = re.compile(variable_pattern_string)
|
||||||
|
|
||||||
# Identify all direct upstream nodes connected to the End node
|
# Split output into ordered segments
|
||||||
direct_upstream_nodes = []
|
output_template = list(re.findall(pattern, output))
|
||||||
for edge in self.edges:
|
|
||||||
if edge.get("target") == end_node_id:
|
|
||||||
source_node_id = edge.get("source")
|
|
||||||
direct_upstream_nodes.append(source_node_id)
|
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
|
# Determine whether each segment is literal text
|
||||||
|
# True -> literal (can be directly output)
|
||||||
|
# False -> variable placeholder (needs runtime value)
|
||||||
|
output_flag = [
|
||||||
|
not bool(variable_pattern.match(item))
|
||||||
|
for item in output_template
|
||||||
|
]
|
||||||
|
|
||||||
# 找到第一个直接上游节点的引用
|
# Stream mode: output activation depends on upstream branch nodes
|
||||||
for match in matches:
|
if self.stream:
|
||||||
referenced_node_id = match.group(1)
|
# Find upstream branch nodes that can control this End node
|
||||||
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
|
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
||||||
|
|
||||||
if referenced_node_id in direct_upstream_nodes:
|
# Build StreamOutputConfig for this End node
|
||||||
# 这是直接上游节点的引用,提取前缀
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
prefix = output_template[:match.start()]
|
# If there is no upstream branch, output is active immediately
|
||||||
|
activate=not has_branch,
|
||||||
|
|
||||||
logger.info(f"[Prefix Analysis] "
|
# Branch nodes that control activation of this End node
|
||||||
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
|
control_nodes=dict(control_nodes),
|
||||||
|
|
||||||
# 标记这个节点为"相邻且被引用"
|
# Convert output segments into OutputContent objects
|
||||||
adjacent_and_referenced.add(referenced_node_id)
|
outputs=list(
|
||||||
|
[
|
||||||
|
OutputContent(
|
||||||
|
literal=output_string,
|
||||||
|
# Literal text can be activated immediately unless blocked by branch
|
||||||
|
activate=activate,
|
||||||
|
# Variable segments are marked explicitly
|
||||||
|
is_variable=not activate
|
||||||
|
)
|
||||||
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
# Cursor for streaming output (initially 0)
|
||||||
|
cursor=0
|
||||||
|
)
|
||||||
|
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||||
|
f"activate: {not has_branch}, "
|
||||||
|
f"control_nodes: {control_nodes},"
|
||||||
|
f"output: {output_template},"
|
||||||
|
f"output_activate: {output_flag}")
|
||||||
|
|
||||||
if prefix:
|
# Non-stream mode: all outputs are activated by default
|
||||||
prefixes[referenced_node_id] = prefix
|
else:
|
||||||
logger.info(f"[Prefix Analysis] "
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
|
activate=True,
|
||||||
|
control_nodes={},
|
||||||
# 只处理第一个直接上游节点的引用
|
outputs=list(
|
||||||
break
|
[
|
||||||
|
OutputContent(
|
||||||
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
|
literal=output_string,
|
||||||
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
|
activate=True,
|
||||||
return prefixes, adjacent_and_referenced
|
is_variable=not activate
|
||||||
|
)
|
||||||
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
cursor=0
|
||||||
|
)
|
||||||
|
|
||||||
def add_nodes(self):
|
def add_nodes(self):
|
||||||
"""Add all nodes from the workflow configuration to the state graph.
|
"""Add all nodes from the workflow configuration to the state graph.
|
||||||
@@ -135,9 +415,6 @@ class GraphBuilder:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
# Analyze End node prefixes if in stream mode
|
|
||||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
|
|
||||||
|
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
node_type = node.get("type")
|
node_type = node.get("type")
|
||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
@@ -171,17 +448,6 @@ class GraphBuilder:
|
|||||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# Inject End node prefix configuration if in stream mode
|
|
||||||
if self.stream and node_id in end_prefixes:
|
|
||||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
|
||||||
logger.info(f"Injected End prefix for node {node_id}")
|
|
||||||
|
|
||||||
# Mark nodes as adjacent and referenced to End node in stream mode
|
|
||||||
if self.stream:
|
|
||||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
|
||||||
if node_id in adjacent_and_referenced:
|
|
||||||
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
|
|
||||||
|
|
||||||
# Wrap node's run method to avoid closure issues
|
# Wrap node's run method to avoid closure issues
|
||||||
if self.stream:
|
if self.stream:
|
||||||
# Stream mode: create an async generator function
|
# Stream mode: create an async generator function
|
||||||
@@ -261,6 +527,7 @@ class GraphBuilder:
|
|||||||
for source_node, branches in conditional_edges.items():
|
for source_node, branches in conditional_edges.items():
|
||||||
def make_router(src, branch_list):
|
def make_router(src, branch_list):
|
||||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
||||||
|
|
||||||
def make_branch_node(node_name, targets):
|
def make_branch_node(node_name, targets):
|
||||||
def node(s):
|
def node(s):
|
||||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
||||||
|
|||||||
@@ -67,10 +67,6 @@ class WorkflowState(TypedDict):
|
|||||||
error: str | None
|
error: str | None
|
||||||
error_node: str | None
|
error_node: str | None
|
||||||
|
|
||||||
# Streaming buffer (stores real-time streaming output of nodes)
|
|
||||||
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
|
|
||||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
|
||||||
|
|
||||||
# node activate status
|
# node activate status
|
||||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||||
|
|
||||||
@@ -300,7 +296,7 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.check_activate(state):
|
if not self.check_activate(state):
|
||||||
yield self.trans_activate(state)
|
yield self.trans_activate(state)
|
||||||
logger.info(f"跳过节点{self.node_id}")
|
logger.info(f"jump node: {self.node_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
import time
|
import time
|
||||||
@@ -313,19 +309,6 @@ class BaseNode(ABC):
|
|||||||
# Get LangGraph's stream writer for sending custom data
|
# Get LangGraph's stream writer for sending custom data
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
|
|
||||||
# Check if this is an End node
|
|
||||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
|
||||||
is_end_node = self.node_type == "end"
|
|
||||||
|
|
||||||
# Check if this node is adjacent to End node (for message type)
|
|
||||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
|
||||||
|
|
||||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
|
||||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
|
||||||
|
|
||||||
# Accumulate complete result (for final wrapping)
|
# Accumulate complete result (for final wrapping)
|
||||||
chunks = []
|
chunks = []
|
||||||
final_result = None
|
final_result = None
|
||||||
@@ -340,66 +323,25 @@ class BaseNode(ABC):
|
|||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
|
|
||||||
# Check if it's a completion marker
|
# Check if it's a completion marker
|
||||||
if isinstance(item, dict) and item.get("__final__"):
|
if item.get("__final__"):
|
||||||
final_result = item["result"]
|
final_result = item["result"]
|
||||||
elif isinstance(item, str):
|
else:
|
||||||
# String is a chunk
|
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
chunks.append(item)
|
content = str(item.get("chunk"))
|
||||||
full_content = "".join(chunks)
|
done = item.get("done", False)
|
||||||
|
chunks.append(content)
|
||||||
|
|
||||||
# Send chunks for all nodes (including End nodes for suffix)
|
# Send chunks for all nodes (including End nodes for suffix)
|
||||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
|
||||||
|
|
||||||
# 1. Send via stream writer (for real-time client updates)
|
# 1. Send via stream writer (for real-time client updates)
|
||||||
writer({
|
writer({
|
||||||
"type": chunk_type, # "message" or "node_chunk"
|
"type": "node_chunk",
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"chunk": item,
|
"chunk": content,
|
||||||
"full_content": full_content,
|
"done": done
|
||||||
"chunk_index": chunk_count
|
|
||||||
})
|
})
|
||||||
|
|
||||||
# 2. Update streaming buffer in state (for downstream nodes)
|
|
||||||
# Only non-End nodes need streaming buffer
|
|
||||||
if not is_end_node:
|
|
||||||
yield {
|
|
||||||
"streaming_buffer": {
|
|
||||||
self.node_id: {
|
|
||||||
"full_content": full_content,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"is_complete": False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Other types are also treated as chunks
|
|
||||||
chunk_count += 1
|
|
||||||
chunk_str = str(item)
|
|
||||||
chunks.append(chunk_str)
|
|
||||||
full_content = "".join(chunks)
|
|
||||||
|
|
||||||
# Send chunks for all nodes
|
|
||||||
writer({
|
|
||||||
"type": chunk_type, # "message" or "node_chunk"
|
|
||||||
"node_id": self.node_id,
|
|
||||||
"chunk": chunk_str,
|
|
||||||
"full_content": full_content,
|
|
||||||
"chunk_index": chunk_count
|
|
||||||
})
|
|
||||||
|
|
||||||
# Only non-End nodes need streaming buffer
|
|
||||||
if not is_end_node:
|
|
||||||
yield {
|
|
||||||
"streaming_buffer": {
|
|
||||||
self.node_id: {
|
|
||||||
"full_content": full_content,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"is_complete": False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||||
@@ -426,16 +368,6 @@ class BaseNode(ABC):
|
|||||||
"looping": state["looping"]
|
"looping": state["looping"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add streaming buffer for non-End nodes
|
|
||||||
if not is_end_node:
|
|
||||||
state_update["streaming_buffer"] = {
|
|
||||||
self.node_id: {
|
|
||||||
"full_content": "".join(chunks),
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"is_complete": True # Mark as complete
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Finally yield state update
|
# Finally yield state update
|
||||||
# LangGraph will merge this into state
|
# LangGraph will merge this into state
|
||||||
yield state_update | self.trans_activate(state)
|
yield state_update | self.trans_activate(state)
|
||||||
@@ -544,6 +476,11 @@ class BaseNode(ABC):
|
|||||||
"error_node": self.node_id
|
"error_node": self.node_id
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
writer = get_stream_writer()
|
||||||
|
writer({
|
||||||
|
"type": "node_error",
|
||||||
|
**node_output
|
||||||
|
})
|
||||||
# 无错误边:抛出异常停止工作流
|
# 无错误边:抛出异常停止工作流
|
||||||
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
||||||
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from app.core.workflow.nodes.code.node import CodeNode
|
||||||
|
|
||||||
|
__all__ = ["CodeNode"]
|
||||||
|
|||||||
50
api/app/core/workflow/nodes/code/config.py
Normal file
50
api/app/core/workflow/nodes/code/config.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from typing import Literal
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class InputVariable(BaseModel):
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
description="variable name"
|
||||||
|
)
|
||||||
|
|
||||||
|
variable: str = Field(
|
||||||
|
...,
|
||||||
|
description="variable selector"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputVariable(BaseModel):
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
description="variable name"
|
||||||
|
)
|
||||||
|
|
||||||
|
type: VariableType = Field(
|
||||||
|
...,
|
||||||
|
description="variable selector"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CodeNodeConfig(BaseNodeConfig):
|
||||||
|
input_variables: list[InputVariable] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="input variables"
|
||||||
|
)
|
||||||
|
|
||||||
|
output_variables: list[OutputVariable] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="output variables"
|
||||||
|
)
|
||||||
|
|
||||||
|
code: str = Field(
|
||||||
|
default="",
|
||||||
|
description="code content"
|
||||||
|
)
|
||||||
|
|
||||||
|
language: Literal['python3', 'nodejs'] = Field(
|
||||||
|
...,
|
||||||
|
description="language"
|
||||||
|
)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user