style(memory): Some code style optimizations

This commit is contained in:
Eternity
2026-03-20 18:22:20 +08:00
parent e8ae46b286
commit c17a2dad2d
8 changed files with 296 additions and 292 deletions

View File

@@ -54,8 +54,8 @@ router = APIRouter(
@router.get("/info", response_model=ApiResponse) @router.get("/info", response_model=ApiResponse)
async def get_storage_info( async def get_storage_info(
storage_id: str, storage_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Example wrapper endpoint - retrieves storage information Example wrapper endpoint - retrieves storage information
@@ -75,24 +75,19 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config( def create_config(
payload: ConfigParamsCreate, payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
try: try:
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型) # 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
@@ -107,9 +102,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}") api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
@@ -119,9 +116,11 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}") api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -129,10 +128,10 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config( def delete_config(
config_id: UUID|int, config_id: UUID | int,
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""删除记忆配置(带终端用户保护) """删除记忆配置(带终端用户保护)
@@ -145,24 +144,24 @@ def delete_config(
force: 设置为 true 可强制删除(即使有终端用户正在使用) force: 设置为 true 可强制删除(即使有终端用户正在使用)
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id=resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info( api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: " f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
f"config_id={config_id}, force={force}" f"config_id={config_id}, force={force}"
) )
try: try:
# 使用带保护的删除服务 # 使用带保护的删除服务
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
result = config_service.delete_config(config_id=config_id, force=force) result = config_service.delete_config(config_id=config_id, force=force)
if result["status"] == "error": if result["status"] == "error":
api_logger.warning( api_logger.warning(
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}" f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
@@ -172,7 +171,7 @@ def delete_config(
msg=result["message"], msg=result["message"],
data={"config_id": str(config_id), "is_default": result.get("is_default", False)} data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
) )
if result["status"] == "warning": if result["status"] == "warning":
api_logger.warning( api_logger.warning(
f"记忆配置正在使用,无法删除: config_id={config_id}, " f"记忆配置正在使用,无法删除: config_id={config_id}, "
@@ -186,7 +185,7 @@ def delete_config(
"force_required": result["force_required"] "force_required": result["force_required"]
} }
) )
api_logger.info( api_logger.info(
f"记忆配置删除成功: config_id={config_id}, " f"记忆配置删除成功: config_id={config_id}, "
f"affected_users={result['affected_users']}" f"affected_users={result['affected_users']}"
@@ -195,7 +194,7 @@ def delete_config(
msg=result["message"], msg=result["message"],
data={"affected_users": result["affected_users"]} data={"affected_users": result["affected_users"]}
) )
except Exception as e: except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True) api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@@ -203,9 +202,9 @@ def delete_config(
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config( def update_config(
payload: ConfigUpdate, payload: ConfigUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -213,12 +212,13 @@ def update_config(
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新 # 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
"config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -231,9 +231,9 @@ def update_config(
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
def update_config_extracted( def update_config_extracted(
payload: ConfigUpdateExtracted, payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -241,7 +241,7 @@ def update_config_extracted(
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -256,11 +256,11 @@ def update_config_extracted(
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 遗忘引擎配置接口已迁移到 memory_forget_controller.py
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
config_id: UUID | int, config_id: UUID | int,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -268,7 +268,7 @@ def read_config_extracted(
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -278,18 +278,19 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}") api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config( def read_all_config(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -303,14 +304,14 @@ def read_all_config(
@router.post("/pilot_run", response_model=None) @router.post("/pilot_run", response_model=None)
async def pilot_run( async def pilot_run(
payload: ConfigPilotRun, payload: ConfigPilotRun,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> StreamingResponse: ) -> StreamingResponse:
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
api_logger.info( api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, " f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}, " f"dialogue_text_length={len(payload.dialogue_text)}, "
@@ -333,9 +334,9 @@ async def pilot_run(
@router.get("/search/kb_type_distribution", response_model=ApiResponse) @router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution( async def get_kb_type_distribution(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
try: try:
result = await kb_type_distribution(end_user_id) result = await kb_type_distribution(end_user_id)
@@ -344,12 +345,12 @@ async def get_kb_type_distribution(
api_logger.error(f"KB type distribution failed: {str(e)}") api_logger.error(f"KB type distribution failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
@router.get("/search/dialogue", response_model=ApiResponse) @router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num( async def search_dialogues_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
try: try:
result = await search_dialogue(end_user_id) result = await search_dialogue(end_user_id)
@@ -361,9 +362,9 @@ async def search_dialogues_num(
@router.get("/search/chunk", response_model=ApiResponse) @router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num( async def search_chunks_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
try: try:
result = await search_chunk(end_user_id) result = await search_chunk(end_user_id)
@@ -375,9 +376,9 @@ async def search_chunks_num(
@router.get("/search/statement", response_model=ApiResponse) @router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num( async def search_statements_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
try: try:
result = await search_statement(end_user_id) result = await search_statement(end_user_id)
@@ -389,9 +390,9 @@ async def search_statements_num(
@router.get("/search/entity", response_model=ApiResponse) @router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num( async def search_entities_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
try: try:
result = await search_entity(end_user_id) result = await search_entity(end_user_id)
@@ -403,9 +404,9 @@ async def search_entities_num(
@router.get("/search", response_model=ApiResponse) @router.get("/search", response_model=ApiResponse)
async def search_all_num( async def search_all_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}") api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try: try:
result = await search_all(end_user_id) result = await search_all(end_user_id)
@@ -417,9 +418,9 @@ async def search_all_num(
@router.get("/search/detials", response_model=ApiResponse) @router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials( async def search_entities_detials(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search details requested for end_user_id: {end_user_id}") api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
try: try:
result = await search_detials(end_user_id) result = await search_detials(end_user_id)
@@ -431,9 +432,9 @@ async def search_entities_detials(
@router.get("/search/edges", response_model=ApiResponse) @router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges( async def search_entity_edges(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
try: try:
result = await search_edges(end_user_id) result = await search_edges(end_user_id)
@@ -443,14 +444,12 @@ async def search_entity_edges(
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse) @router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api( async def get_hot_memory_tags_api(
limit: int = 10, limit: int = 10,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
获取热门记忆标签带Redis缓存 获取热门记忆标签带Redis缓存
@@ -461,18 +460,18 @@ async def get_hot_memory_tags_api(
- 缓存未命中:~600-800ms取决于LLM速度 - 缓存未命中:~600-800ms取决于LLM速度
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 构建缓存键 # 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}" cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}") api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
try: try:
# 尝试从Redis缓存获取 # 尝试从Redis缓存获取
import json import json
from app.aioRedis import aio_redis_get, aio_redis_set from app.aioRedis import aio_redis_get, aio_redis_set
cached_result = await aio_redis_get(cache_key) cached_result = await aio_redis_get(cache_key)
if cached_result: if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}") api_logger.info(f"Cache hit for key: {cache_key}")
@@ -481,11 +480,11 @@ async def get_hot_memory_tags_api(
return success(data=data, msg="查询成功(缓存)") return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError: except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh") api_logger.warning(f"Failed to parse cached data, will refresh")
# 缓存未命中,执行查询 # 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query") api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit) result = await analytics_hot_memory_tags(db, current_user, limit)
# 写入缓存过期时间5分钟 # 写入缓存过期时间5分钟
# 注意result是列表需要转换为JSON字符串 # 注意result是列表需要转换为JSON字符串
try: try:
@@ -495,9 +494,9 @@ async def get_hot_memory_tags_api(
except Exception as cache_error: except Exception as cache_error:
# 缓存写入失败不影响主流程 # 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}") api_logger.warning(f"Failed to cache result: {str(cache_error)}")
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}") api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache( async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
清除热门标签缓存 清除热门标签缓存
@@ -516,12 +515,12 @@ async def clear_hot_memory_tags_cache(
- 数据更新后立即生效 - 数据更新后立即生效
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}") api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
try: try:
from app.aioRedis import aio_redis_delete from app.aioRedis import aio_redis_delete
# 清除所有limit的缓存常见的limit值 # 清除所有limit的缓存常见的limit值
cleared_count = 0 cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]: for limit in [5, 10, 15, 20, 30, 50]:
@@ -530,12 +529,12 @@ async def clear_hot_memory_tags_cache(
if result: if result:
cleared_count += 1 cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}") api_logger.info(f"Cleared cache for key: {cache_key}")
return success( return success(
data={"cleared_count": cleared_count}, data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存" msg=f"成功清除 {cleared_count} 个缓存"
) )
except Exception as e: except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}") api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) @router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api( async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
except Exception as e: except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}") api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))

View File

@@ -598,8 +598,10 @@ class LangChainAgent:
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", total_tokens = response_meta.get("token_usage", {}).get(
0) if response_meta else 0 "total_tokens",
0
) if response_meta else 0
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:

View File

@@ -44,21 +44,21 @@ def parse_historical_datetime(v):
""" """
if v is None: if v is None:
return v return v
# 处理 Neo4j DateTime 对象 # 处理 Neo4j DateTime 对象
if hasattr(v, 'to_native'): if hasattr(v, 'to_native'):
return v.to_native() return v.to_native()
# 处理 Python datetime 对象 # 处理 Python datetime 对象
if isinstance(v, datetime): if isinstance(v, datetime):
return v return v
if isinstance(v, str): if isinstance(v, str):
# 匹配 ISO 8601 格式YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM] # 匹配 ISO 8601 格式YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
# 支持1-4位年份 # 支持1-4位年份
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?' pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
match = re.match(pattern, v) match = re.match(pattern, v)
if match: if match:
try: try:
year = int(match.group(1)) year = int(match.group(1))
@@ -68,31 +68,31 @@ def parse_historical_datetime(v):
minute = int(match.group(5)) if match.group(5) else 0 minute = int(match.group(5)) if match.group(5) else 0
second = int(match.group(6)) if match.group(6) else 0 second = int(match.group(6)) if match.group(6) else 0
microsecond = 0 microsecond = 0
# 处理微秒 # 处理微秒
if match.group(7): if match.group(7):
# 补齐或截断到6位 # 补齐或截断到6位
us_str = match.group(7).ljust(6, '0')[:6] us_str = match.group(7).ljust(6, '0')[:6]
microsecond = int(us_str) microsecond = int(us_str)
# 处理时区 # 处理时区
tzinfo = None tzinfo = None
if 'Z' in v or match.group(8): if 'Z' in v or match.group(8):
tzinfo = timezone.utc tzinfo = timezone.utc
# 创建 datetime 对象 # 创建 datetime 对象
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo) return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
except (ValueError, OverflowError): except (ValueError, OverflowError):
# 日期值无效如月份13、日期32等 # 日期值无效如月份13、日期32等
return None return None
# 如果不匹配模式,尝试使用 fromisoformat用于标准格式 # 如果不匹配模式,尝试使用 fromisoformat用于标准格式
try: try:
return datetime.fromisoformat(v.replace('Z', '+00:00')) return datetime.fromisoformat(v.replace('Z', '+00:00'))
except Exception: except Exception:
return None return None
return v return v
@@ -167,7 +167,7 @@ class EntityEntityEdge(Edge):
source_statement_id: str = Field(..., description="Statement where this relationship was extracted") source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start") valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
@field_validator('valid_at', 'invalid_at', mode='before') @field_validator('valid_at', 'invalid_at', mode='before')
@classmethod @classmethod
def validate_datetime(cls, v): def validate_datetime(cls, v):
@@ -206,7 +206,8 @@ class DialogueNode(Node):
ref_id: str = Field(..., description="Reference identifier of the dialog") ref_id: str = Field(..., description="Reference identifier of the dialog")
content: str = Field(..., description="Dialogue content") content: str = Field(..., description="Dialogue content")
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this dialogue (integer or string)")
class StatementNode(Node): class StatementNode(Node):
@@ -241,17 +242,17 @@ class StatementNode(Node):
chunk_id: str = Field(..., description="ID of the parent chunk") chunk_id: str = Field(..., description="ID of the parent chunk")
stmt_type: str = Field(..., description="Type of the statement") stmt_type: str = Field(..., description="Type of the statement")
statement: str = Field(..., description="The statement text content") statement: str = Field(..., description="The statement text content")
# Speaker identification # Speaker identification
speaker: Optional[str] = Field( speaker: Optional[str] = Field(
None, None,
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
) )
# Emotion fields (ordered as requested, emotion_intensity first for display) # Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field( emotion_intensity: Optional[float] = Field(
None, None,
ge=0.0, ge=0.0,
le=1.0, le=1.0,
description="Emotion intensity: 0.0-1.0 (displayed on node)" description="Emotion intensity: 0.0-1.0 (displayed on node)"
) )
@@ -264,25 +265,26 @@ class StatementNode(Node):
description="Emotion subject: self/other/object" description="Emotion subject: self/other/object"
) )
emotion_type: Optional[str] = Field( emotion_type: Optional[str] = Field(
None, None,
description="Emotion type: joy/sadness/anger/fear/surprise/neutral" description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
) )
emotion_keywords: Optional[List[str]] = Field( emotion_keywords: Optional[List[str]] = Field(
default_factory=list, default_factory=list,
description="Emotion keywords list, max 3 items" description="Emotion keywords list, max 3 items"
) )
# Temporal fields # Temporal fields
temporal_info: TemporalInfo = Field(..., description="Temporal information") temporal_info: TemporalInfo = Field(..., description="Temporal information")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start") valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
# Embedding and other fields # Embedding and other fields
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
default=0.5, default=0.5,
@@ -309,13 +311,13 @@ class StatementNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed" description="Total number of times this node has been accessed"
) )
@field_validator('valid_at', 'invalid_at', mode='before') @field_validator('valid_at', 'invalid_at', mode='before')
@classmethod @classmethod
def validate_datetime(cls, v): def validate_datetime(cls, v):
"""使用通用的历史日期解析函数""" """使用通用的历史日期解析函数"""
return parse_historical_datetime(v) return parse_historical_datetime(v)
@field_validator('emotion_type', mode='before') @field_validator('emotion_type', mode='before')
@classmethod @classmethod
def validate_emotion_type(cls, v): def validate_emotion_type(cls, v):
@@ -326,7 +328,7 @@ class StatementNode(Node):
if v not in valid_types: if v not in valid_types:
raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
return v return v
@field_validator('emotion_subject', mode='before') @field_validator('emotion_subject', mode='before')
@classmethod @classmethod
def validate_emotion_subject(cls, v): def validate_emotion_subject(cls, v):
@@ -337,7 +339,7 @@ class StatementNode(Node):
if v not in valid_subjects: if v not in valid_subjects:
raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
return v return v
@field_validator('emotion_keywords', mode='before') @field_validator('emotion_keywords', mode='before')
@classmethod @classmethod
def validate_emotion_keywords(cls, v): def validate_emotion_keywords(cls, v):
@@ -405,19 +407,20 @@ class ExtractedEntityNode(Node):
entity_type: str = Field(..., description="Type of the entity") entity_type: str = Field(..., description="Type of the entity")
description: str = Field(..., description="Entity description") description: str = Field(..., description="Entity description")
example: str = Field( example: str = Field(
default="", default="",
description="A concise example (around 20 characters) to help understand the entity" description="A concise example (around 20 characters) to help understand the entity"
) )
aliases: List[str] = Field( aliases: List[str] = Field(
default_factory=list, default_factory=list,
description="Entity aliases - alternative names for this entity" description="Entity aliases - alternative names for this entity"
) )
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity") # fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
default=0.5, default=0.5,
@@ -444,16 +447,16 @@ class ExtractedEntityNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed" description="Total number of times this node has been accessed"
) )
# Explicit Memory Classification # Explicit Memory Classification
is_explicit_memory: bool = Field( is_explicit_memory: bool = Field(
default=False, default=False,
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
) )
@field_validator('aliases', mode='before') @field_validator('aliases', mode='before')
@classmethod @classmethod
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
"""Validate and clean aliases field using utility function. """Validate and clean aliases field using utility function.
This validator ensures that the aliases field is always a valid list of strings. This validator ensures that the aliases field is always a valid list of strings.
@@ -507,8 +510,9 @@ class MemorySummaryNode(Node):
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") config_id: Optional[int | str] = Field(None,
description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties # ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field( original_statement_id: Optional[str] = Field(
None, None,
@@ -522,7 +526,7 @@ class MemorySummaryNode(Node):
None, None,
description="Timestamp when the nodes were merged" description="Timestamp when the nodes were merged"
) )
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
default=0.5, default=0.5,

View File

@@ -227,7 +227,8 @@ class EmbeddingGenerator:
# 打印前几个嵌入向量的维度 # 打印前几个嵌入向量的维度
for i in range(min(5, len(embeddings))): for i in range(min(5, len(embeddings))):
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") print(f"实体 '{entity_texts[i]}' "
f"嵌入向量维度: {len(embeddings[i])}")
# 将嵌入向量赋值给实体 # 将嵌入向量赋值给实体
for ent, emb in zip(entity_refs, embeddings): for ent, emb in zip(entity_refs, embeddings):

View File

@@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id,
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
# Entity Merge Query # Entity Merge Query
MERGE_ENTITIES = """ MERGE_ENTITIES = """
MATCH (canonical:ExtractedEntity {id: $canonical_id}) MATCH (canonical:ExtractedEntity {id: $canonical_id})
@@ -829,9 +828,8 @@ neo4j_query_all = """
other as entity2 other as entity2
""" """
'''针对当前节点下扩长的句子,实体和总结''' '''针对当前节点下扩长的句子,实体和总结'''
Memory_Timeline_ExtractedEntity=""" Memory_Timeline_ExtractedEntity = """
MATCH (n)-[r1]-(e)-[r2]-(ms) MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND (ms:ExtractedEntity OR ms:MemorySummary) AND (ms:ExtractedEntity OR ms:MemorySummary)
@@ -869,7 +867,7 @@ RETURN
""" """
Memory_Timeline_MemorySummary=""" Memory_Timeline_MemorySummary = """
MATCH (n)-[r1]-(e)-[r2]-(ms) MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) =$id WHERE elementId(n) =$id
AND (ms:MemorySummary OR ms:ExtractedEntity) AND (ms:MemorySummary OR ms:ExtractedEntity)
@@ -904,7 +902,7 @@ RETURN
} }
) AS statement; ) AS statement;
""" """
Memory_Timeline_Statement=""" Memory_Timeline_Statement = """
MATCH (n) MATCH (n)
WHERE elementId(n) = $id WHERE elementId(n) = $id
@@ -947,7 +945,7 @@ RETURN
""" """
'''针对当前节点,主要获取更加完整的句子节点''' '''针对当前节点,主要获取更加完整的句子节点'''
Memory_Space_Emotion_Statement=""" Memory_Space_Emotion_Statement = """
MATCH (n) MATCH (n)
WHERE elementId(n) = $id WHERE elementId(n) = $id
RETURN RETURN
@@ -957,7 +955,7 @@ RETURN
n.statement AS statement; n.statement AS statement;
""" """
Memory_Space_Emotion_MemorySummary=""" Memory_Space_Emotion_MemorySummary = """
MATCH (n)-[]-(e) MATCH (n)-[]-(e)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND EXISTS { AND EXISTS {
@@ -970,7 +968,7 @@ RETURN DISTINCT
e.emotion_type AS emotion_type, e.emotion_type AS emotion_type,
e.statement AS statement; e.statement AS statement;
""" """
Memory_Space_Emotion_ExtractedEntity=""" Memory_Space_Emotion_ExtractedEntity = """
MATCH (n)-[]-(e) MATCH (n)-[]-(e)
WHERE elementId(n) = $id WHERE elementId(n) = $id
AND EXISTS { AND EXISTS {
@@ -985,18 +983,18 @@ RETURN DISTINCT
'''获取实体''' '''获取实体'''
Memory_Space_User=""" Memory_Space_User = """
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
WHERE n.end_user_id = $end_user_id AND m.name="用户" WHERE n.end_user_id = $end_user_id AND m.name="用户"
return DISTINCT elementId(m) as id return DISTINCT elementId(m) as id
""" """
Memory_Space_Entity=""" Memory_Space_Entity = """
MATCH (n)-[]-(m) MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person" WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN RETURN
DISTINCT m.name as name,m.end_user_id as end_user_id DISTINCT m.name as name,m.end_user_id as end_user_id
""" """
Memory_Space_Associative=""" Memory_Space_Associative = """
MATCH (u)-[]-(x)-[]-(h) MATCH (u)-[]-(x)-[]-(h)
WHERE elementId(u) = $user_id WHERE elementId(u) = $user_id
AND elementId(h) = $id AND elementId(h) = $id
@@ -1060,7 +1058,6 @@ Graph_Node_query = """
""" """
# ============================================================ # ============================================================
# Community 节点 & BELONGS_TO_COMMUNITY 边 # Community 节点 & BELONGS_TO_COMMUNITY 边
# ============================================================ # ============================================================

View File

@@ -8,9 +8,6 @@ import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
# ============================================================================ # ============================================================================
# 从 json_schema.py 迁移的 Schema # 从 json_schema.py 迁移的 Schema
# ============================================================================ # ============================================================================
@@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel):
class ConflictResultSchema(BaseModel): class ConflictResultSchema(BaseModel):
"""Schema for the conflict result data in the reflexion_data.json file.""" """Schema for the conflict result data in the reflexion_data.json file."""
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") data: List[BaseDataSchema] = Field(...,
description="The conflict memory data. Only contains conflicting records when conflict is True.")
conflict: bool = Field(..., description="Whether the memory is in conflict.") conflict: bool = Field(..., description="Whether the memory is in conflict.")
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") quality_assessment: Optional[QualityAssessmentSchema] = Field(None,
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
memory_verify: Optional[MemoryVerifySchema] = Field(None,
description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_data(cls, v): def _normalize_data(cls, v):
@@ -101,16 +101,19 @@ class ChangeRecordSchema(BaseModel):
- entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式
""" """
field: List[Dict[str, Any]] = Field( field: List[Dict[str, Any]] = Field(
..., ...,
description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
) )
class ResolvedSchema(BaseModel): class ResolvedSchema(BaseModel):
"""Schema for the resolved memory data in the reflexion_data""" """Schema for the resolved memory data in the reflexion_data"""
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None,
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
change: Optional[List[ChangeRecordSchema]] = Field(None,
description="List of detailed change records with IDs and field information.")
class SingleReflexionResultSchema(BaseModel): class SingleReflexionResultSchema(BaseModel):
@@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel):
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
type: str = Field("reflexion_result", description="The type identifier.") type: str = Field("reflexion_result", description="The type identifier.")
class ReflexionResultSchema(BaseModel): class ReflexionResultSchema(BaseModel):
"""Schema for the complete reflexion result data - a list of individual conflict resolutions.""" """Schema for the complete reflexion result data - a list of individual conflict resolutions."""
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") results: List[SingleReflexionResultSchema] = Field(...,
description="List of individual conflict resolution results, grouped by conflict type.")
@model_validator(mode="before") @model_validator(mode="before")
def _normalize_resolved(cls, v): def _normalize_resolved(cls, v):
@@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel):
# Composite key identifying a config row # Composite key identifying a config row
class ConfigKey(BaseModel): # 配置参数键模型 class ConfigKey(BaseModel): # 配置参数键模型
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
user_id: str = Field("user_id", description="用户标识(字符串)") user_id: str | None = Field(default=None, description="用户标识(字符串)")
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)")
# Allowed chunking strategies (extendable later) # Allowed chunking strategies (extendable later)
@@ -228,23 +233,25 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
config_name: str = Field("配置名称", description="配置名称(字符串)") config_name: str = Field("配置名称", description="配置名称(字符串)")
config_desc: str = Field("配置描述", description="配置描述(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)")
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间IDUUID") workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间IDUUID")
# 本体场景关联(可选) # 本体场景关联(可选)
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景IDUUID关联ontology_scene表") scene_id: Optional[uuid.UUID] = Field(None, description="本体场景IDUUID关联ontology_scene表")
# 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name前端无需传入 # 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name前端无需传入
pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充") pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充")
# 模型配置字段(可选,用于手动指定或自动填充) # 模型配置字段(可选,用于手动指定或自动填充)
llm_id: Optional[str] = Field(None, description="LLM模型配置ID") llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致") reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)") # config_name: str = Field("配置名称", description="配置名称(字符串)")
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串")
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
@@ -255,7 +262,7 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
config_id:Union[uuid.UUID, int, str] = None config_id: Union[uuid.UUID, int, str] = None
llm_id: Optional[str] = Field(None, description="LLM模型配置ID") llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
@@ -322,14 +329,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
# 遗忘引擎配置参数更新模型 # 遗忘引擎配置参数更新模型
config_id:Union[uuid.UUID, int, str] = None config_id: Union[uuid.UUID, int, str] = None
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5") lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度0-1 小数;默认 0.5")
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率0-1 小数;默认 0.5")
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度0-1 小数;默认 0.0")
class ConfigPilotRun(BaseModel): # 试运行触发请求模型 class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行") custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -364,11 +371,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None)
def fail( def fail(
msg: str, msg: str,
error_code: str = "ERROR", error_code: str = "ERROR",
data: Optional[Any] = None, data: Optional[Any] = None,
time: Optional[int] = None, time: Optional[int] = None,
query_preview: Optional[str] = None, query_preview: Optional[str] = None,
) -> ApiResponse: ) -> ApiResponse:
payload = data payload = data
if query_preview is not None: if query_preview is not None:
@@ -387,12 +394,13 @@ def fail(
time=time or _now_ms(), time=time or _now_ms(),
) )
class GenerateCacheRequest(BaseModel): class GenerateCacheRequest(BaseModel):
"""缓存生成请求模型""" """缓存生成请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
end_user_id: Optional[str] = Field( end_user_id: Optional[str] = Field(
None, None,
description="终端用户IDUUID格式。如果提供只为该用户生成如果不提供为当前工作空间的所有用户生成" description="终端用户IDUUID格式。如果提供只为该用户生成如果不提供为当前工作空间的所有用户生成"
) )
@@ -404,7 +412,7 @@ class GenerateCacheRequest(BaseModel):
class ForgettingTriggerRequest(BaseModel): class ForgettingTriggerRequest(BaseModel):
"""手动触发遗忘周期请求模型""" """手动触发遗忘周期请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
end_user_id: str = Field(..., description="组ID即终端用户ID必填") end_user_id: str = Field(..., description="组ID即终端用户ID必填")
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数默认100") max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数默认100")
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数默认30天") min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数默认30天")
@@ -413,7 +421,7 @@ class ForgettingTriggerRequest(BaseModel):
class ForgettingConfigResponse(BaseModel): class ForgettingConfigResponse(BaseModel):
"""遗忘引擎配置响应模型""" """遗忘引擎配置响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID支持UUID、整数或字符串")
decay_constant: float = Field(..., description="衰减常数 d") decay_constant: float = Field(..., description="衰减常数 d")
lambda_time: float = Field(..., description="时间衰减参数") lambda_time: float = Field(..., description="时间衰减参数")
@@ -432,7 +440,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
"""遗忘引擎配置更新请求模型""" """遗忘引擎配置更新请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
@@ -448,7 +456,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
class ForgettingCycleHistoryPoint(BaseModel): class ForgettingCycleHistoryPoint(BaseModel):
"""遗忘周期历史数据点模型(用于趋势图)""" """遗忘周期历史数据点模型(用于趋势图)"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
date: str = Field(..., description="日期(格式: '1/1', '1/2'") date: str = Field(..., description="日期(格式: '1/1', '1/2'")
merged_count: int = Field(..., description="每日融合节点数") merged_count: int = Field(..., description="每日融合节点数")
average_activation: Optional[float] = Field(None, description="平均激活值") average_activation: Optional[float] = Field(None, description="平均激活值")
@@ -459,7 +467,7 @@ class ForgettingCycleHistoryPoint(BaseModel):
class PendingForgettingNode(BaseModel): class PendingForgettingNode(BaseModel):
"""待遗忘节点模型""" """待遗忘节点模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
node_id: str = Field(..., description="节点ID") node_id: str = Field(..., description="节点ID")
node_type: str = Field(..., description="节点类型statement/entity/summary") node_type: str = Field(..., description="节点类型statement/entity/summary")
content_summary: str = Field(..., description="内容摘要") content_summary: str = Field(..., description="内容摘要")
@@ -472,7 +480,8 @@ class ForgettingStatsResponse(BaseModel):
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
node_distribution: Dict[str, int] = Field(..., description="节点类型分布") node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据每天取最后一次执行") recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)") timestamp: int = Field(..., description="统计时间(时间戳)")
@@ -480,7 +489,7 @@ class ForgettingStatsResponse(BaseModel):
class ForgettingReportResponse(BaseModel): class ForgettingReportResponse(BaseModel):
"""遗忘周期报告响应模型""" """遗忘周期报告响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
merged_count: int = Field(..., description="融合的节点对数量") merged_count: int = Field(..., description="融合的节点对数量")
nodes_before: int = Field(..., description="遗忘前的节点总数") nodes_before: int = Field(..., description="遗忘前的节点总数")
nodes_after: int = Field(..., description="遗忘后的节点总数") nodes_after: int = Field(..., description="遗忘后的节点总数")
@@ -495,7 +504,7 @@ class ForgettingReportResponse(BaseModel):
class ForgettingCurvePoint(BaseModel): class ForgettingCurvePoint(BaseModel):
"""遗忘曲线数据点模型""" """遗忘曲线数据点模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
day: int = Field(..., description="天数") day: int = Field(..., description="天数")
activation: float = Field(..., description="激活值") activation: float = Field(..., description="激活值")
retention_rate: float = Field(..., description="保持率(与激活值相同)") retention_rate: float = Field(..., description="保持率(与激活值相同)")
@@ -504,7 +513,7 @@ class ForgettingCurvePoint(BaseModel):
class ForgettingCurveRequest(BaseModel): class ForgettingCurveRequest(BaseModel):
"""遗忘曲线请求模型""" """遗忘曲线请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1") importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数0-1")
days: int = Field(60, ge=1, le=365, description="模拟天数默认60天") days: int = Field(60, ge=1, le=365, description="模拟天数默认60天")
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识UUID或int)")
@@ -513,6 +522,6 @@ class ForgettingCurveRequest(BaseModel):
class ForgettingCurveResponse(BaseModel): class ForgettingCurveResponse(BaseModel):
"""遗忘曲线响应模型""" """遗忘曲线响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表") curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表")
config: Dict[str, Any] = Field(..., description="使用的配置参数") config: Dict[str, Any] = Field(..., description="使用的配置参数")

View File

@@ -11,9 +11,11 @@ import time
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import ( from app.core.memory.analytics.hot_memory_tags import (
get_hot_memory_tags,
get_raw_tags_from_db, get_raw_tags_from_db,
filter_tags_with_llm, filter_tags_with_llm,
) )
@@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import (
) )
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message from app.utils.sse_utils import format_sse_message
from dotenv import load_dotenv
from sqlalchemy.orm import Session
logger = get_logger(__name__) logger = get_logger(__name__)
config_logger = get_config_logger() config_logger = get_config_logger()
@@ -45,10 +45,10 @@ _neo4j_connector = Neo4jConnector()
class MemoryStorageService: class MemoryStorageService:
"""Service for memory storage operations""" """Service for memory storage operations"""
def __init__(self): def __init__(self):
logger.info("MemoryStorageService initialized") logger.info("MemoryStorageService initialized")
async def get_storage_info(self) -> dict: async def get_storage_info(self) -> dict:
""" """
Example wrapper method - retrieves storage information Example wrapper method - retrieves storage information
@@ -59,17 +59,17 @@ class MemoryStorageService:
Storage information dictionary Storage information dictionary
""" """
logger.info("Getting storage info ") logger.info("Getting storage info ")
# Empty wrapper - implement your logic here # Empty wrapper - implement your logic here
result = { result = {
"status": "active", "status": "active",
"message": "This is an example wrapper" "message": "This is an example wrapper"
} }
return result
class DataConfigService: # 数据配置服务类PostgreSQL return result
class DataConfigService: # 数据配置服务类PostgreSQL
"""Service layer for config params CRUD. """Service layer for config params CRUD.
使用 SQLAlchemy ORM 进行数据库操作。 使用 SQLAlchemy ORM 进行数据库操作。
@@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return data_list return data_list
# --- Create --- # --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
# 业务层检查同一工作空间下是否已存在同名配置 # 业务层检查同一工作空间下是否已存在同名配置
if params.workspace_id and params.config_name: if params.workspace_id and params.config_name:
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
@@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return None return None
# --- Delete --- # --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID
success = MemoryConfigRepository.delete(self.db, key.config_id) success = MemoryConfigRepository.delete(self.db, key.config_id)
if not success: if not success:
raise ValueError("未找到配置") raise ValueError("未找到配置")
return {"affected": 1} return {"affected": 1}
# --- Update --- # --- Update ---
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
config = MemoryConfigRepository.update(self.db, update) config = MemoryConfigRepository.update(self.db, update)
if not config: if not config:
raise ValueError("未找到配置") raise ValueError("未找到配置")
return {"affected": 1} return {"affected": 1}
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
config = MemoryConfigRepository.update_extracted(self.db, update) config = MemoryConfigRepository.update_extracted(self.db, update)
if not config: if not config:
raise ValueError("未找到配置") raise ValueError("未找到配置")
@@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config() # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
# --- Read --- # --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id) result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
if not result: if not result:
raise ValueError("未找到配置") raise ValueError("未找到配置")
return result return result
# --- Read All --- # --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数
results = MemoryConfigRepository.get_all(self.db, workspace_id) results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 检查并修正 pruning_scene 与 scene_name 不一致的记录 # 检查并修正 pruning_scene 与 scene_name 不一致的记录
@@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
except (ValueError, TypeError): except (ValueError, TypeError):
config_id_old = None config_id_old = None
if config_id_old: if config_id_old:
memory_config=config_id_old memory_config = config_id_old
else: else:
memory_config=config.config_id memory_config = config.config_id
config_dict = { config_dict = {
"config_id": memory_config, "config_id": memory_config,
"config_name": config.config_name, "config_name": config.config_name,
@@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
return self._convert_timestamps_to_format(data_list) return self._convert_timestamps_to_format(data_list)
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]: async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
""" """
流式执行试运行,产生 SSE 格式的进度事件 流式执行试运行,产生 SSE 格式的进度事件
@@ -311,14 +309,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
""" """
from pathlib import Path from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2]) project_root = str(Path(__file__).resolve().parents[2])
try: try:
# 发出初始进度事件 # 发出初始进度事件
yield format_sse_message("starting", { yield format_sse_message("starting", {
"message": "开始试运行...", "message": "开始试运行...",
"time": int(time.time() * 1000) "time": int(time.time() * 1000)
}) })
# 步骤 1: 配置加载和验证(数据库优先) # 步骤 1: 配置加载和验证(数据库优先)
payload_cid = str(getattr(payload, "config_id", "") or "").strip() payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None cid: Optional[str] = payload_cid if payload_cid else None
@@ -344,27 +342,28 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 关联了本体场景,优先使用 custom_text # 关联了本体场景,优先使用 custom_text
if hasattr(payload, 'custom_text') and payload.custom_text: if hasattr(payload, 'custom_text') and payload.custom_text:
dialogue_text = payload.custom_text.strip() dialogue_text = payload.custom_text.strip()
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") logger.info(
f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
else: else:
# 如果没有提供 custom_text回退到 dialogue_text # 如果没有提供 custom_text回退到 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") logger.info(
f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
else: else:
# 没有关联本体场景,使用 dialogue_text # 没有关联本体场景,使用 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}") logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}")
# 验证最终使用的文本不为空 # 验证最终使用的文本不为空
if not dialogue_text: if not dialogue_text:
raise ValueError("试运行模式必须提供有效的文本内容dialogue_text 或 custom_text") raise ValueError("试运行模式必须提供有效的文本内容dialogue_text 或 custom_text")
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
# 步骤 2: 创建进度回调函数捕获管线进度 # 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件 # 使用队列在回调和生成器之间传递进度事件
progress_queue: asyncio.Queue = asyncio.Queue() progress_queue: asyncio.Queue = asyncio.Queue()
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None: async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
""" """
进度回调函数,将进度事件放入队列 进度回调函数,将进度事件放入队列
@@ -375,14 +374,15 @@ class DataConfigService: # 数据配置服务类PostgreSQL
data: 可选的结果数据(用于传递节点执行结果) data: 可选的结果数据(用于传递节点执行结果)
""" """
await progress_queue.put((stage, message, data)) await progress_queue.put((stage, message, data))
# 步骤 3: 在后台任务中执行管线 # 步骤 3: 在后台任务中执行管线
async def run_pipeline(): async def run_pipeline():
"""在后台执行管线并捕获异常""" """在后台执行管线并捕获异常"""
try: try:
from app.services.pilot_run_service import run_pilot_extraction from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") logger.info(
f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction( await run_pilot_extraction(
memory_config=memory_config, memory_config=memory_config,
dialogue_text=dialogue_text, dialogue_text=dialogue_text,
@@ -391,60 +391,60 @@ class DataConfigService: # 数据配置服务类PostgreSQL
language=language, language=language,
) )
logger.info("[PILOT_RUN_STREAM] pipeline_main completed") logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
# 标记管线完成 # 标记管线完成
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None)) await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
except Exception as e: except Exception as e:
# 将异常放入队列 # 将异常放入队列
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None)) await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
# 启动后台任务 # 启动后台任务
pipeline_task = asyncio.create_task(run_pipeline()) pipeline_task = asyncio.create_task(run_pipeline())
# 步骤 4: 从队列中读取进度事件并发出 # 步骤 4: 从队列中读取进度事件并发出
while True: while True:
try: try:
# 等待进度事件,设置超时以检测客户端断开 # 等待进度事件,设置超时以检测客户端断开
stage, message, data = await asyncio.wait_for( stage, message, data = await asyncio.wait_for(
progress_queue.get(), progress_queue.get(),
timeout=0.5 timeout=0.5
) )
# 检查特殊标记 # 检查特殊标记
if stage == "__PIPELINE_COMPLETE__": if stage == "__PIPELINE_COMPLETE__":
break break
elif stage == "__PIPELINE_ERROR__": elif stage == "__PIPELINE_ERROR__":
raise RuntimeError(message) raise RuntimeError(message)
# 构建进度事件数据 # 构建进度事件数据
progress_data = { progress_data = {
"message": message, "message": message,
"time": int(time.time() * 1000) "time": int(time.time() * 1000)
} }
# 如果有结果数据,添加到事件中 # 如果有结果数据,添加到事件中
if data: if data:
progress_data["data"] = data progress_data["data"] = data
# 发出进度事件,使用 stage 作为事件类型 # 发出进度事件,使用 stage 作为事件类型
yield format_sse_message(stage, progress_data) yield format_sse_message(stage, progress_data)
except TimeoutError: except TimeoutError:
# 超时,继续等待(这允许检测客户端断开) # 超时,继续等待(这允许检测客户端断开)
continue continue
# 等待管线任务完成 # 等待管线任务完成
await pipeline_task await pipeline_task
# 步骤 5: 读取提取结果 # 步骤 5: 读取提取结果
from app.core.config import settings from app.core.config import settings
result_path = settings.get_memory_output_path("extracted_result.json") result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path): if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
with open(result_path, "r", encoding="utf-8") as rf: with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf) extracted_result = json.load(rf)
# 步骤 6: 计算本体覆盖率并合并到结果中 # 步骤 6: 计算本体覆盖率并合并到结果中
result_data = { result_data = {
"config_id": cid, "config_id": cid,
@@ -460,15 +460,15 @@ class DataConfigService: # 数据配置服务类PostgreSQL
result_data["ontology_coverage"] = ontology_coverage result_data["ontology_coverage"] = ontology_coverage
except Exception as cov_err: except Exception as cov_err:
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True) logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
yield format_sse_message("result", result_data) yield format_sse_message("result", result_data)
# 步骤 7: 发出完成事件 # 步骤 7: 发出完成事件
yield format_sse_message("done", { yield format_sse_message("done", {
"message": "试运行完成", "message": "试运行完成",
"time": int(time.time() * 1000) "time": int(time.time() * 1000)
}) })
except asyncio.CancelledError: except asyncio.CancelledError:
# 客户端断开连接 # 客户端断开连接
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming") logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
@@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000) "time": int(time.time() * 1000)
}) })
async def _compute_ontology_coverage( async def _compute_ontology_coverage(
self, self,
extracted_result: Dict[str, Any], extracted_result: Dict[str, Any],
memory_config, memory_config,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。 """根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
@@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD) # Ensure env for connector (e.g., NEO4J_PASSWORD)
load_dotenv()
_neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
@@ -664,7 +661,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
# 检查结果是否为空或长度不足 # 检查结果是否为空或长度不足
if not result or len(result) < 4: if not result or len(result) < 4:
data = { data = {
"total": 0, "total": 0,
"distribution": [ "distribution": [
{"type": "dialogue", "count": 0}, {"type": "dialogue", "count": 0},
{"type": "chunk", "count": 0}, {"type": "chunk", "count": 0},
@@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
) )
return result return result
async def analytics_hot_memory_tags( async def analytics_hot_memory_tags(
db: Session, db: Session,
current_user: User, current_user: User,
limit: int = 10 limit: int = 10
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取热门记忆标签按数量排序并返回前N个 获取热门记忆标签按数量排序并返回前N个
@@ -721,27 +719,27 @@ async def analytics_hot_memory_tags(
from app.services.memory_dashboard_service import get_workspace_end_users from app.services.memory_dashboard_service import get_workspace_end_users
# 使用 asyncio.to_thread 避免阻塞事件循环 # 使用 asyncio.to_thread 避免阻塞事件循环
end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user) end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user)
if not end_users: if not end_users:
return [] return []
# 步骤1: 收集所有用户的原始标签不调用LLM # 步骤1: 收集所有用户的原始标签不调用LLM
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
all_raw_tags = [] all_raw_tags = []
for end_user in end_users: for end_user in end_users:
raw_tags = await get_raw_tags_from_db( raw_tags = await get_raw_tags_from_db(
connector, connector,
str(end_user.id), str(end_user.id),
limit=raw_limit, limit=raw_limit,
by_user=False by_user=False
) )
if raw_tags: if raw_tags:
all_raw_tags.extend(raw_tags) all_raw_tags.extend(raw_tags)
if not all_raw_tags: if not all_raw_tags:
return [] return []
# 步骤2: 聚合相同标签的频率 # 步骤2: 聚合相同标签的频率
tag_frequency_map = {} tag_frequency_map = {}
for tag_name, frequency in all_raw_tags: for tag_name, frequency in all_raw_tags:
@@ -749,36 +747,36 @@ async def analytics_hot_memory_tags(
tag_frequency_map[tag_name] += frequency tag_frequency_map[tag_name] += frequency
else: else:
tag_frequency_map[tag_name] = frequency tag_frequency_map[tag_name] = frequency
# 步骤3: 按频率降序排序取前raw_limit个 # 步骤3: 按频率降序排序取前raw_limit个
sorted_tags = sorted( sorted_tags = sorted(
tag_frequency_map.items(), tag_frequency_map.items(),
key=lambda x: x[1], key=lambda x: x[1],
reverse=True reverse=True
)[:raw_limit] )[:raw_limit]
if not sorted_tags: if not sorted_tags:
return [] return []
# 步骤4: 只调用一次LLM进行筛选 # 步骤4: 只调用一次LLM进行筛选
tag_names = [tag for tag, _ in sorted_tags] tag_names = [tag for tag, _ in sorted_tags]
# 使用第一个用户的end_user_id来获取LLM配置 # 使用第一个用户的end_user_id来获取LLM配置
# 因为同一工作空间下的用户应该使用相同的配置 # 因为同一工作空间下的用户应该使用相同的配置
first_end_user_id = str(end_users[0].id) first_end_user_id = str(end_users[0].id)
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id) filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)
# 步骤5: 根据LLM筛选结果构建最终列表保留频率 # 步骤5: 根据LLM筛选结果构建最终列表保留频率
final_tags = [] final_tags = []
for tag, freq in sorted_tags: for tag, freq in sorted_tags:
if tag in filtered_tag_names: if tag in filtered_tag_names:
final_tags.append((tag, freq)) final_tags.append((tag, freq))
# 步骤6: 只返回前limit个 # 步骤6: 只返回前limit个
top_tags = final_tags[:limit] top_tags = final_tags[:limit]
return [{"name": t, "frequency": f} for t, f in top_tags] return [{"name": t, "frequency": f} for t, f in top_tags]
finally: finally:
await connector.close() await connector.close()
@@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
source = "log" source = "log"
total = ( total = (
stats.get("chunk_count", 0) stats.get("chunk_count", 0)
+ stats.get("statements_count", 0) + stats.get("statements_count", 0)
+ stats.get("triplet_entities_count", 0) + stats.get("triplet_entities_count", 0)
+ stats.get("triplet_relations_count", 0) + stats.get("triplet_relations_count", 0)
+ stats.get("temporal_count", 0) + stats.get("temporal_count", 0)
) )
# 计算"最新一次活动多久前"(仅日志来源时有效) # 计算"最新一次活动多久前"(仅日志来源时有效)
@@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) ->
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
return data return data

View File

@@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True) @celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, def write_message_task(
user_rag_memory_id: str, self,
language: str = "zh") -> Dict[str, Any]: end_user_id: str,
message: list[dict],
config_id: str | int,
storage_type: str,
user_rag_memory_id: str,
language: str = "zh"
) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
Args: Args:
end_user_id: Group ID for the memory agent (also used as end_user_id) end_user_id: Group ID for the memory agent (also used as end_user_id)
@@ -1105,14 +1111,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
try: try:
with get_db_context() as db: with get_db_context() as db:
actual_config_id = resolve_config_id(config_id, db) actual_config_id = resolve_config_id(config_id, db)
print(100 * '-') logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} "
print(actual_config_id) f"(type: {type(actual_config_id).__name__})")
print(100 * '-')
logger.info(
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e: except (ValueError, AttributeError) as e:
logger.error( logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} "
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") f"(type: {type(config_id).__name__}), error: {e}")
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": f"Invalid config_id format: {config_id} - {str(e)}", "error": f"Invalid config_id format: {config_id} - {str(e)}",
@@ -1151,8 +1154,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info( logger.info(f"[CELERY WRITE] Task completed successfully "
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
try: try:
@@ -1167,7 +1170,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
) )
except Exception as _e: except Exception as _e:
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"result": result, "result": result,
@@ -2672,7 +2674,7 @@ def write_perceptual_memory(
ignore_result=False, ignore_result=False,
max_retries=0, max_retries=0,
acks_late=False, acks_late=False,
time_limit=7200, # 2小时硬超时 time_limit=7200, # 2小时硬超时
soft_time_limit=6900, soft_time_limit=6900,
) )
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
@@ -2749,7 +2751,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
llm_model_id=llm_model_id, llm_model_id=llm_model_id,
) )
logger.info(f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}") logger.info(
f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}")
await engine.full_clustering(end_user_id) await engine.full_clustering(end_user_id)
initialized += 1 initialized += 1
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
@@ -2772,12 +2775,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
} }
try: try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
loop = set_asyncio_event_loop() loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time result["elapsed_time"] = time.time() - start_time