From c17a2dad2d3a7273b1d45783c8291a2d338bbb99 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:22:20 +0800 Subject: [PATCH] style(memory): Some code style optimizations --- .../controllers/memory_storage_controller.py | 204 +++++++++--------- api/app/core/agent/langchain_agent.py | 6 +- api/app/core/memory/models/graph_models.py | 74 ++++--- .../embedding_generation.py | 3 +- api/app/repositories/neo4j/cypher_queries.py | 21 +- api/app/schemas/memory_storage_schema.py | 83 +++---- api/app/services/memory_storage_service.py | 158 +++++++------- api/app/tasks.py | 39 ++-- 8 files changed, 296 insertions(+), 292 deletions(-) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index d91dfc36..d8b39325 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -54,8 +54,8 @@ router = APIRouter( @router.get("/info", response_model=ApiResponse) async def get_storage_info( - storage_id: str, - current_user: User = Depends(get_current_user) + storage_id: str, + current_user: User = Depends(get_current_user) ): """ Example wrapper endpoint - retrieves storage information @@ -75,24 +75,19 @@ async def get_storage_info( return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) - - - - - -@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 +@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 def create_config( - payload: ConfigParamsCreate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), + payload: ConfigParamsCreate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), ) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}") try: # 将 workspace_id 注入到 payload 中(保持为 UUID 类型) @@ -107,9 +102,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {err_str}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) @@ -119,9 +116,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) @@ -129,10 +128,10 @@ def create_config( @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) def delete_config( - config_id: UUID|int, - force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """删除记忆配置(带终端用户保护) @@ -145,24 +144,24 @@ def delete_config( force: 设置为 true 可强制删除(即使有终端用户正在使用) """ workspace_id = current_user.current_workspace_id - config_id=resolve_config_id(config_id, db) + config_id = resolve_config_id(config_id, db) # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info( f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: " f"config_id={config_id}, force={force}" ) - + try: # 使用带保护的删除服务 from app.services.memory_config_service import MemoryConfigService - + config_service = MemoryConfigService(db) result = config_service.delete_config(config_id=config_id, force=force) - + if result["status"] == "error": api_logger.warning( f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}" @@ -172,7 +171,7 @@ def delete_config( msg=result["message"], data={"config_id": str(config_id), "is_default": result.get("is_default", False)} ) - + if result["status"] == "warning": api_logger.warning( f"记忆配置正在使用,无法删除: config_id={config_id}, " @@ -186,7 +185,7 @@ def delete_config( "force_required": result["force_required"] } ) - + api_logger.info( f"记忆配置删除成功: config_id={config_id}, " f"affected_users={result['affected_users']}" @@ -195,7 +194,7 @@ def delete_config( msg=result["message"], data={"affected_users": result["affected_users"]} ) - + except Exception as e: api_logger.error(f"Delete config failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) @@ -203,9 +202,9 @@ def delete_config( @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc def update_config( - payload: ConfigUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -213,12 +212,13 @@ def update_config( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + # 校验至少有一个字段需要更新 if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") - return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") - + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", + "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -231,9 +231,9 @@ def update_config( @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 def update_config_extracted( - payload: ConfigUpdateExtracted, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdateExtracted, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -241,7 +241,7 @@ def update_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -256,11 +256,11 @@ def update_config_extracted( # 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config -@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 +@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 def read_config_extracted( - config_id: UUID | int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id config_id = resolve_config_id(config_id, db) @@ -268,7 +268,7 @@ def read_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}") try: svc = DataConfigService(db) @@ -278,18 +278,19 @@ def read_config_extracted( api_logger.error(f"Read config extracted failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) -@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 + +@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 def read_all_config( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置") try: svc = DataConfigService(db) @@ -303,14 +304,14 @@ def read_all_config( @router.post("/pilot_run", response_model=None) async def pilot_run( - payload: ConfigPilotRun, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigPilotRun, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> StreamingResponse: # 使用集中化的语言校验 language = get_language_from_header(language_type) - + api_logger.info( f"Pilot run requested: config_id={payload.config_id}, " f"dialogue_text_length={len(payload.dialogue_text)}, " @@ -333,9 +334,9 @@ async def pilot_run( @router.get("/search/kb_type_distribution", response_model=ApiResponse) async def get_kb_type_distribution( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") try: result = await kb_type_distribution(end_user_id) @@ -344,12 +345,12 @@ async def get_kb_type_distribution( api_logger.error(f"KB type distribution failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e)) - + @router.get("/search/dialogue", response_model=ApiResponse) async def search_dialogues_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") try: result = await search_dialogue(end_user_id) @@ -361,9 +362,9 @@ async def search_dialogues_num( @router.get("/search/chunk", response_model=ApiResponse) async def search_chunks_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") try: result = await search_chunk(end_user_id) @@ -375,9 +376,9 @@ async def search_chunks_num( @router.get("/search/statement", response_model=ApiResponse) async def search_statements_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") try: result = await search_statement(end_user_id) @@ -389,9 +390,9 @@ async def search_statements_num( @router.get("/search/entity", response_model=ApiResponse) async def search_entities_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") try: result = await search_entity(end_user_id) @@ -403,9 +404,9 @@ async def search_entities_num( @router.get("/search", response_model=ApiResponse) async def search_all_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search all requested for end_user_id: {end_user_id}") try: result = await search_all(end_user_id) @@ -417,9 +418,9 @@ async def search_all_num( @router.get("/search/detials", response_model=ApiResponse) async def search_entities_detials( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search details requested for end_user_id: {end_user_id}") try: result = await search_detials(end_user_id) @@ -431,9 +432,9 @@ async def search_entities_detials( @router.get("/search/edges", response_model=ApiResponse) async def search_entity_edges( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") try: result = await search_edges(end_user_id) @@ -443,14 +444,12 @@ async def search_entity_edges( return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) - - @router.get("/analytics/hot_memory_tags", response_model=ApiResponse) async def get_hot_memory_tags_api( - limit: int = 10, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), - ) -> dict: + limit: int = 10, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +) -> dict: """ 获取热门记忆标签(带Redis缓存) @@ -461,18 +460,18 @@ async def get_hot_memory_tags_api( - 缓存未命中:~600-800ms(取决于LLM速度) """ workspace_id = current_user.current_workspace_id - + # 构建缓存键 cache_key = f"hot_memory_tags:{workspace_id}:{limit}" - + api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}") - + try: # 尝试从Redis缓存获取 import json from app.aioRedis import aio_redis_get, aio_redis_set - + cached_result = await aio_redis_get(cache_key) if cached_result: api_logger.info(f"Cache hit for key: {cache_key}") @@ -481,11 +480,11 @@ async def get_hot_memory_tags_api( return success(data=data, msg="查询成功(缓存)") except json.JSONDecodeError: api_logger.warning(f"Failed to parse cached data, will refresh") - + # 缓存未命中,执行查询 api_logger.info(f"Cache miss for key: {cache_key}, executing query") result = await analytics_hot_memory_tags(db, current_user, limit) - + # 写入缓存(过期时间:5分钟) # 注意:result是列表,需要转换为JSON字符串 try: @@ -495,9 +494,9 @@ async def get_hot_memory_tags_api( except Exception as cache_error: # 缓存写入失败不影响主流程 api_logger.warning(f"Failed to cache result: {str(cache_error)}") - + return success(data=result, msg="查询成功") - + except Exception as e: api_logger.error(f"Hot memory tags failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) @@ -505,8 +504,8 @@ async def get_hot_memory_tags_api( @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) async def clear_hot_memory_tags_cache( - current_user: User = Depends(get_current_user), - ) -> dict: + current_user: User = Depends(get_current_user), +) -> dict: """ 清除热门标签缓存 @@ -516,12 +515,12 @@ async def clear_hot_memory_tags_cache( - 数据更新后立即生效 """ workspace_id = current_user.current_workspace_id - + api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}") - + try: from app.aioRedis import aio_redis_delete - + # 清除所有limit的缓存(常见的limit值) cleared_count = 0 for limit in [5, 10, 15, 20, 30, 50]: @@ -530,12 +529,12 @@ async def clear_hot_memory_tags_cache( if result: cleared_count += 1 api_logger.info(f"Cleared cache for key: {cache_key}") - + return success( - data={"cleared_count": cleared_count}, + data={"cleared_count": cleared_count}, msg=f"成功清除 {cleared_count} 个缓存" ) - + except Exception as e: api_logger.error(f"Clear cache failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e)) @@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache( @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( - current_user: User = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> dict: workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") @@ -553,4 +552,3 @@ async def get_recent_activity_stats_api( except Exception as e: api_logger.error(f"Recent activity stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 88b6371c..464a668a 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -598,8 +598,10 @@ class LangChainAgent: for msg in reversed(output_messages): if isinstance(msg, AIMessage): response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get("total_tokens", - 0) if response_meta else 0 + total_tokens = response_meta.get("token_usage", {}).get( + "total_tokens", + 0 + ) if response_meta else 0 yield total_tokens break if memory_flag: diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 1880b9ab..5a2d8c2e 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -44,21 +44,21 @@ def parse_historical_datetime(v): """ if v is None: return v - + # 处理 Neo4j DateTime 对象 if hasattr(v, 'to_native'): return v.to_native() - + # 处理 Python datetime 对象 if isinstance(v, datetime): return v - + if isinstance(v, str): # 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM] # 支持1-4位年份 pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?' match = re.match(pattern, v) - + if match: try: year = int(match.group(1)) @@ -68,31 +68,31 @@ def parse_historical_datetime(v): minute = int(match.group(5)) if match.group(5) else 0 second = int(match.group(6)) if match.group(6) else 0 microsecond = 0 - + # 处理微秒 if match.group(7): # 补齐或截断到6位 us_str = match.group(7).ljust(6, '0')[:6] microsecond = int(us_str) - + # 处理时区 tzinfo = None if 'Z' in v or match.group(8): tzinfo = timezone.utc - + # 创建 datetime 对象 return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo) - + except (ValueError, OverflowError): # 日期值无效(如月份13、日期32等) return None - + # 如果不匹配模式,尝试使用 fromisoformat(用于标准格式) try: return datetime.fromisoformat(v.replace('Z', '+00:00')) except Exception: return None - + return v @@ -167,7 +167,7 @@ class EntityEntityEdge(Edge): source_statement_id: str = Field(..., description="Statement where this relationship was extracted") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): @@ -206,7 +206,8 @@ class DialogueNode(Node): ref_id: str = Field(..., description="Reference identifier of the dialog") content: str = Field(..., description="Dialogue content") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this dialogue (integer or string)") class StatementNode(Node): @@ -241,17 +242,17 @@ class StatementNode(Node): chunk_id: str = Field(..., description="ID of the parent chunk") stmt_type: str = Field(..., description="Type of the statement") statement: str = Field(..., description="The statement text content") - + # Speaker identification speaker: Optional[str] = Field( None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" ) - + # Emotion fields (ordered as requested, emotion_intensity first for display) emotion_intensity: Optional[float] = Field( - None, - ge=0.0, + None, + ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0 (displayed on node)" ) @@ -264,25 +265,26 @@ class StatementNode(Node): description="Emotion subject: self/other/object" ) emotion_type: Optional[str] = Field( - None, + None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral" ) emotion_keywords: Optional[List[str]] = Field( default_factory=list, description="Emotion keywords list, max 3 items" ) - + # Temporal fields temporal_info: TemporalInfo = Field(..., description="Temporal information") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + # Embedding and other fields statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this statement (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -309,13 +311,13 @@ class StatementNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): """使用通用的历史日期解析函数""" return parse_historical_datetime(v) - + @field_validator('emotion_type', mode='before') @classmethod def validate_emotion_type(cls, v): @@ -326,7 +328,7 @@ class StatementNode(Node): if v not in valid_types: raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") return v - + @field_validator('emotion_subject', mode='before') @classmethod def validate_emotion_subject(cls, v): @@ -337,7 +339,7 @@ class StatementNode(Node): if v not in valid_subjects: raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") return v - + @field_validator('emotion_keywords', mode='before') @classmethod def validate_emotion_keywords(cls, v): @@ -405,19 +407,20 @@ class ExtractedEntityNode(Node): entity_type: str = Field(..., description="Type of the entity") description: str = Field(..., description="Entity description") example: str = Field( - default="", + default="", description="A concise example (around 20 characters) to help understand the entity" ) aliases: List[str] = Field( - default_factory=list, + default_factory=list, description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this entity (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -444,16 +447,16 @@ class ExtractedEntityNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + # Explicit Memory Classification is_explicit_memory: bool = Field( default=False, description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" ) - + @field_validator('aliases', mode='before') @classmethod - def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 + def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 """Validate and clean aliases field using utility function. This validator ensures that the aliases field is always a valid list of strings. @@ -507,8 +510,9 @@ class MemorySummaryNode(Node): memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this summary (integer or string)") + # ACT-R Forgetting Engine Properties original_statement_id: Optional[str] = Field( None, @@ -522,7 +526,7 @@ class MemorySummaryNode(Node): None, description="Timestamp when the nodes were merged" ) - + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py index 72f3641e..a0bccc25 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py @@ -227,7 +227,8 @@ class EmbeddingGenerator: # 打印前几个嵌入向量的维度 for i in range(min(5, len(embeddings))): - print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") + print(f"实体 '{entity_texts[i]}' " + f"嵌入向量维度: {len(embeddings[i])}") # 将嵌入向量赋值给实体 for ent, emb in zip(entity_refs, embeddings): diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 5b2e5f1e..0ac7dcb1 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id, RETURN elementId(r) AS uuid """ - # Entity Merge Query MERGE_ENTITIES = """ MATCH (canonical:ExtractedEntity {id: $canonical_id}) @@ -829,9 +828,8 @@ neo4j_query_all = """ other as entity2 """ - '''针对当前节点下扩长的句子,实体和总结''' -Memory_Timeline_ExtractedEntity=""" +Memory_Timeline_ExtractedEntity = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) = $id AND (ms:ExtractedEntity OR ms:MemorySummary) @@ -869,7 +867,7 @@ RETURN """ -Memory_Timeline_MemorySummary=""" +Memory_Timeline_MemorySummary = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) =$id AND (ms:MemorySummary OR ms:ExtractedEntity) @@ -904,7 +902,7 @@ RETURN } ) AS statement; """ -Memory_Timeline_Statement=""" +Memory_Timeline_Statement = """ MATCH (n) WHERE elementId(n) = $id @@ -947,7 +945,7 @@ RETURN """ '''针对当前节点,主要获取更加完整的句子节点''' -Memory_Space_Emotion_Statement=""" +Memory_Space_Emotion_Statement = """ MATCH (n) WHERE elementId(n) = $id RETURN @@ -957,7 +955,7 @@ RETURN n.statement AS statement; """ -Memory_Space_Emotion_MemorySummary=""" +Memory_Space_Emotion_MemorySummary = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -970,7 +968,7 @@ RETURN DISTINCT e.emotion_type AS emotion_type, e.statement AS statement; """ -Memory_Space_Emotion_ExtractedEntity=""" +Memory_Space_Emotion_ExtractedEntity = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -985,18 +983,18 @@ RETURN DISTINCT '''获取实体''' -Memory_Space_User=""" +Memory_Space_User = """ MATCH (n)-[r]->(m) WHERE n.end_user_id = $end_user_id AND m.name="用户" return DISTINCT elementId(m) as id """ -Memory_Space_Entity=""" +Memory_Space_Entity = """ MATCH (n)-[]-(m) WHERE elementId(m) = $id AND m.entity_type = "Person" RETURN DISTINCT m.name as name,m.end_user_id as end_user_id """ -Memory_Space_Associative=""" +Memory_Space_Associative = """ MATCH (u)-[]-(x)-[]-(h) WHERE elementId(u) = $user_id AND elementId(h) = $id @@ -1060,7 +1058,6 @@ Graph_Node_query = """ """ - # ============================================================ # Community 节点 & BELONGS_TO_COMMUNITY 边 # ============================================================ diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 046b79e7..c8abbc46 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -8,9 +8,6 @@ import uuid from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator - - - # ============================================================================ # 从 json_schema.py 迁移的 Schema # ============================================================================ @@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel): class ConflictResultSchema(BaseModel): """Schema for the conflict result data in the reflexion_data.json file.""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") + data: List[BaseDataSchema] = Field(..., + description="The conflict memory data. Only contains conflicting records when conflict is True.") conflict: bool = Field(..., description="Whether the memory is in conflict.") - quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") - memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") + quality_assessment: Optional[QualityAssessmentSchema] = Field(None, + description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") + memory_verify: Optional[MemoryVerifySchema] = Field(None, + description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") @model_validator(mode="before") def _normalize_data(cls, v): @@ -101,16 +101,19 @@ class ChangeRecordSchema(BaseModel): - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 """ field: List[Dict[str, Any]] = Field( - ..., + ..., description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" ) + class ResolvedSchema(BaseModel): """Schema for the resolved memory data in the reflexion_data""" original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") - resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") - change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") + resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, + description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") + change: Optional[List[ChangeRecordSchema]] = Field(None, + description="List of detailed change records with IDs and field information.") class SingleReflexionResultSchema(BaseModel): @@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel): resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") type: str = Field("reflexion_result", description="The type identifier.") + class ReflexionResultSchema(BaseModel): """Schema for the complete reflexion result data - a list of individual conflict resolutions.""" - results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") + results: List[SingleReflexionResultSchema] = Field(..., + description="List of individual conflict resolution results, grouped by conflict type.") @model_validator(mode="before") def _normalize_resolved(cls, v): @@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel): # Composite key identifying a config row class ConfigKey(BaseModel): # 配置参数键模型 model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") - user_id: str = Field("user_id", description="用户标识(字符串)") - apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") + user_id: str | None = Field(default=None, description="用户标识(字符串)") + apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)") # Allowed chunking strategies (extendable later) @@ -228,23 +233,25 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, config_name: str = Field("配置名称", description="配置名称(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)") workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)") - + # 本体场景关联(可选) scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表") - + # 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入) pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充") - + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") + + class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) model_config = ConfigDict(populate_by_name=True, extra="forbid") # config_name: str = Field("配置名称", description="配置名称(字符串)") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 @@ -255,7 +262,7 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") @@ -322,14 +329,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 # 遗忘引擎配置参数更新模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0") class ConfigPilotRun(BaseModel): # 试运行触发请求模型 - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行") model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -364,11 +371,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) def fail( - msg: str, - error_code: str = "ERROR", - data: Optional[Any] = None, - time: Optional[int] = None, - query_preview: Optional[str] = None, + msg: str, + error_code: str = "ERROR", + data: Optional[Any] = None, + time: Optional[int] = None, + query_preview: Optional[str] = None, ) -> ApiResponse: payload = data if query_preview is not None: @@ -387,12 +394,13 @@ def fail( time=time or _now_ms(), ) + class GenerateCacheRequest(BaseModel): """缓存生成请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: Optional[str] = Field( - None, + None, description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成" ) @@ -404,7 +412,7 @@ class GenerateCacheRequest(BaseModel): class ForgettingTriggerRequest(BaseModel): """手动触发遗忘周期请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: str = Field(..., description="组ID(即终端用户ID,必填)") max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)") min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)") @@ -413,7 +421,7 @@ class ForgettingTriggerRequest(BaseModel): class ForgettingConfigResponse(BaseModel): """遗忘引擎配置响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") decay_constant: float = Field(..., description="衰减常数 d") lambda_time: float = Field(..., description="时间衰减参数") @@ -432,7 +440,7 @@ class ForgettingConfigUpdateRequest(BaseModel): """遗忘引擎配置更新请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") @@ -448,7 +456,7 @@ class ForgettingConfigUpdateRequest(BaseModel): class ForgettingCycleHistoryPoint(BaseModel): """遗忘周期历史数据点模型(用于趋势图)""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + date: str = Field(..., description="日期(格式: '1/1', '1/2')") merged_count: int = Field(..., description="每日融合节点数") average_activation: Optional[float] = Field(None, description="平均激活值") @@ -459,7 +467,7 @@ class ForgettingCycleHistoryPoint(BaseModel): class PendingForgettingNode(BaseModel): """待遗忘节点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + node_id: str = Field(..., description="节点ID") node_type: str = Field(..., description="节点类型:statement/entity/summary") content_summary: str = Field(..., description="内容摘要") @@ -472,7 +480,8 @@ class ForgettingStatsResponse(BaseModel): model_config = ConfigDict(populate_by_name=True, extra="forbid") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") node_distribution: Dict[str, int] = Field(..., description="节点类型分布") - recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") + recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., + description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") timestamp: int = Field(..., description="统计时间(时间戳)") @@ -480,7 +489,7 @@ class ForgettingStatsResponse(BaseModel): class ForgettingReportResponse(BaseModel): """遗忘周期报告响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + merged_count: int = Field(..., description="融合的节点对数量") nodes_before: int = Field(..., description="遗忘前的节点总数") nodes_after: int = Field(..., description="遗忘后的节点总数") @@ -495,7 +504,7 @@ class ForgettingReportResponse(BaseModel): class ForgettingCurvePoint(BaseModel): """遗忘曲线数据点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + day: int = Field(..., description="天数") activation: float = Field(..., description="激活值") retention_rate: float = Field(..., description="保持率(与激活值相同)") @@ -504,7 +513,7 @@ class ForgettingCurvePoint(BaseModel): class ForgettingCurveRequest(BaseModel): """遗忘曲线请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)") days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") @@ -513,6 +522,6 @@ class ForgettingCurveRequest(BaseModel): class ForgettingCurveResponse(BaseModel): """遗忘曲线响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表") config: Dict[str, Any] = Field(..., description="使用的配置参数") diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 6e7c1ad4..264ae4df 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -11,9 +11,11 @@ import time from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional +from dotenv import load_dotenv +from sqlalchemy.orm import Session + from app.core.logging_config import get_config_logger, get_logger from app.core.memory.analytics.hot_memory_tags import ( - get_hot_memory_tags, get_raw_tags_from_db, filter_tags_with_llm, ) @@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import ( ) from app.services.memory_config_service import MemoryConfigService from app.utils.sse_utils import format_sse_message -from dotenv import load_dotenv -from sqlalchemy.orm import Session logger = get_logger(__name__) config_logger = get_config_logger() @@ -45,10 +45,10 @@ _neo4j_connector = Neo4jConnector() class MemoryStorageService: """Service for memory storage operations""" - + def __init__(self): logger.info("MemoryStorageService initialized") - + async def get_storage_info(self) -> dict: """ Example wrapper method - retrieves storage information @@ -59,17 +59,17 @@ class MemoryStorageService: Storage information dictionary """ logger.info("Getting storage info ") - + # Empty wrapper - implement your logic here result = { "status": "active", "message": "This is an example wrapper" } - - return result - -class DataConfigService: # 数据配置服务类(PostgreSQL) + return result + + +class DataConfigService: # 数据配置服务类(PostgreSQL) """Service layer for config params CRUD. 使用 SQLAlchemy ORM 进行数据库操作。 @@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return data_list # --- Create --- - def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) + def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) # 业务层检查同一工作空间下是否已存在同名配置 if params.workspace_id and params.config_name: from app.models.memory_config_model import MemoryConfig @@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return None # --- Delete --- - def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) + def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) success = MemoryConfigRepository.delete(self.db, key.config_id) if not success: raise ValueError("未找到配置") return {"affected": 1} # --- Update --- - def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 + def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 config = MemoryConfigRepository.update(self.db, update) if not config: raise ValueError("未找到配置") return {"affected": 1} - def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 + def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 config = MemoryConfigRepository.update_extracted(self.db, update) if not config: raise ValueError("未找到配置") @@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config() # --- Read --- - def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 + def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id) if not result: raise ValueError("未找到配置") return result # --- Read All --- - def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 + def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数 results = MemoryConfigRepository.get_all(self.db, workspace_id) # 检查并修正 pruning_scene 与 scene_name 不一致的记录 @@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) except (ValueError, TypeError): config_id_old = None - if config_id_old: - memory_config=config_id_old + memory_config = config_id_old else: - memory_config=config.config_id + memory_config = config.config_id config_dict = { "config_id": memory_config, "config_name": config.config_name, @@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 return self._convert_timestamps_to_format(data_list) - async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]: """ 流式执行试运行,产生 SSE 格式的进度事件 @@ -311,14 +309,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) """ from pathlib import Path project_root = str(Path(__file__).resolve().parents[2]) - + try: # 发出初始进度事件 yield format_sse_message("starting", { "message": "开始试运行...", "time": int(time.time() * 1000) }) - + # 步骤 1: 配置加载和验证(数据库优先) payload_cid = str(getattr(payload, "config_id", "") or "").strip() cid: Optional[str] = payload_cid if payload_cid else None @@ -344,27 +342,28 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 关联了本体场景,优先使用 custom_text if hasattr(payload, 'custom_text') and payload.custom_text: dialogue_text = payload.custom_text.strip() - logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") + logger.info( + f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") else: # 如果没有提供 custom_text,回退到 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" - logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") + logger.info( + f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") else: # 没有关联本体场景,使用 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}") - + # 验证最终使用的文本不为空 if not dialogue_text: raise ValueError("试运行模式必须提供有效的文本内容(dialogue_text 或 custom_text)") - - logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") + logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") # 步骤 2: 创建进度回调函数捕获管线进度 # 使用队列在回调和生成器之间传递进度事件 progress_queue: asyncio.Queue = asyncio.Queue() - + async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None: """ 进度回调函数,将进度事件放入队列 @@ -375,14 +374,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) data: 可选的结果数据(用于传递节点执行结果) """ await progress_queue.put((stage, message, data)) - + # 步骤 3: 在后台任务中执行管线 async def run_pipeline(): """在后台执行管线并捕获异常""" try: from app.services.pilot_run_service import run_pilot_extraction - - logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") + + logger.info( + f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") await run_pilot_extraction( memory_config=memory_config, dialogue_text=dialogue_text, @@ -391,60 +391,60 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) language=language, ) logger.info("[PILOT_RUN_STREAM] pipeline_main completed") - + # 标记管线完成 await progress_queue.put(("__PIPELINE_COMPLETE__", "", None)) except Exception as e: # 将异常放入队列 await progress_queue.put(("__PIPELINE_ERROR__", str(e), None)) - + # 启动后台任务 pipeline_task = asyncio.create_task(run_pipeline()) - + # 步骤 4: 从队列中读取进度事件并发出 while True: try: # 等待进度事件,设置超时以检测客户端断开 stage, message, data = await asyncio.wait_for( - progress_queue.get(), + progress_queue.get(), timeout=0.5 ) - + # 检查特殊标记 if stage == "__PIPELINE_COMPLETE__": break elif stage == "__PIPELINE_ERROR__": raise RuntimeError(message) - + # 构建进度事件数据 progress_data = { "message": message, "time": int(time.time() * 1000) } - + # 如果有结果数据,添加到事件中 if data: progress_data["data"] = data - + # 发出进度事件,使用 stage 作为事件类型 yield format_sse_message(stage, progress_data) - + except TimeoutError: # 超时,继续等待(这允许检测客户端断开) continue - + # 等待管线任务完成 await pipeline_task - + # 步骤 5: 读取提取结果 from app.core.config import settings result_path = settings.get_memory_output_path("extracted_result.json") if not os.path.isfile(result_path): raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") - + with open(result_path, "r", encoding="utf-8") as rf: extracted_result = json.load(rf) - + # 步骤 6: 计算本体覆盖率并合并到结果中 result_data = { "config_id": cid, @@ -460,15 +460,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) result_data["ontology_coverage"] = ontology_coverage except Exception as cov_err: logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True) - + yield format_sse_message("result", result_data) - + # 步骤 7: 发出完成事件 yield format_sse_message("done", { "message": "试运行完成", "time": int(time.time() * 1000) }) - + except asyncio.CancelledError: # 客户端断开连接 logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming") @@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "time": int(time.time() * 1000) }) - async def _compute_ontology_coverage( - self, - extracted_result: Dict[str, Any], - memory_config, + self, + extracted_result: Dict[str, Any], + memory_config, ) -> Optional[Dict[str, Any]]: """根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。 @@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # Ensure env for connector (e.g., NEO4J_PASSWORD) -load_dotenv() -_neo4j_connector = Neo4jConnector() async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: @@ -664,7 +661,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A # 检查结果是否为空或长度不足 if not result or len(result) < 4: data = { - "total": 0, + "total": 0, "distribution": [ {"type": "dialogue", "count": 0}, {"type": "chunk", "count": 0}, @@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] ) return result + async def analytics_hot_memory_tags( - db: Session, - current_user: User, - limit: int = 10 + db: Session, + current_user: User, + limit: int = 10 ) -> List[Dict[str, Any]]: """ 获取热门记忆标签,按数量排序并返回前N个 @@ -721,27 +719,27 @@ async def analytics_hot_memory_tags( from app.services.memory_dashboard_service import get_workspace_end_users # 使用 asyncio.to_thread 避免阻塞事件循环 end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user) - + if not end_users: return [] - + # 步骤1: 收集所有用户的原始标签(不调用LLM) connector = Neo4jConnector() try: all_raw_tags = [] for end_user in end_users: raw_tags = await get_raw_tags_from_db( - connector, - str(end_user.id), - limit=raw_limit, + connector, + str(end_user.id), + limit=raw_limit, by_user=False ) if raw_tags: all_raw_tags.extend(raw_tags) - + if not all_raw_tags: return [] - + # 步骤2: 聚合相同标签的频率 tag_frequency_map = {} for tag_name, frequency in all_raw_tags: @@ -749,36 +747,36 @@ async def analytics_hot_memory_tags( tag_frequency_map[tag_name] += frequency else: tag_frequency_map[tag_name] = frequency - + # 步骤3: 按频率降序排序,取前raw_limit个 sorted_tags = sorted( - tag_frequency_map.items(), - key=lambda x: x[1], + tag_frequency_map.items(), + key=lambda x: x[1], reverse=True )[:raw_limit] - + if not sorted_tags: return [] - + # 步骤4: 只调用一次LLM进行筛选 tag_names = [tag for tag, _ in sorted_tags] - + # 使用第一个用户的end_user_id来获取LLM配置 # 因为同一工作空间下的用户应该使用相同的配置 first_end_user_id = str(end_users[0].id) filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id) - + # 步骤5: 根据LLM筛选结果构建最终列表(保留频率) final_tags = [] for tag, freq in sorted_tags: if tag in filtered_tag_names: final_tags.append((tag, freq)) - + # 步骤6: 只返回前limit个 top_tags = final_tags[:limit] - + return [{"name": t, "frequency": f} for t, f in top_tags] - + finally: await connector.close() @@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> source = "log" total = ( - stats.get("chunk_count", 0) - + stats.get("statements_count", 0) - + stats.get("triplet_entities_count", 0) - + stats.get("triplet_relations_count", 0) - + stats.get("temporal_count", 0) + stats.get("chunk_count", 0) + + stats.get("statements_count", 0) + + stats.get("triplet_entities_count", 0) + + stats.get("triplet_relations_count", 0) + + stats.get("temporal_count", 0) ) # 计算"最新一次活动多久前"(仅日志来源时有效) @@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} return data - - diff --git a/api/app/tasks.py b/api/app/tasks.py index f5258330..c37e564e 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, - user_rag_memory_id: str, - language: str = "zh") -> Dict[str, Any]: +def write_message_task( + self, + end_user_id: str, + message: list[dict], + config_id: str | int, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" +) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) @@ -1105,14 +1111,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s try: with get_db_context() as db: actual_config_id = resolve_config_id(config_id, db) - print(100 * '-') - print(actual_config_id) - print(100 * '-') - logger.info( - f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") + logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} " + f"(type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: - logger.error( - f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") + logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} " + f"(type: {type(config_id).__name__}), error: {e}") return { "status": "FAILURE", "error": f"Invalid config_id format: {config_id} - {str(e)}", @@ -1151,8 +1154,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time - logger.info( - f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + logger.info(f"[CELERY WRITE] Task completed successfully " + f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: @@ -1167,7 +1170,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s ) except Exception as _e: logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") - return { "status": "SUCCESS", "result": result, @@ -2672,7 +2674,7 @@ def write_perceptual_memory( ignore_result=False, max_retries=0, acks_late=False, - time_limit=7200, # 2小时硬超时 + time_limit=7200, # 2小时硬超时 soft_time_limit=6900, ) def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: @@ -2749,7 +2751,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s llm_model_id=llm_model_id, ) - logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") + logger.info( + f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") await engine.full_clustering(end_user_id) initialized += 1 logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") @@ -2772,12 +2775,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time