From c17a2dad2d3a7273b1d45783c8291a2d338bbb99 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:22:20 +0800 Subject: [PATCH 1/8] style(memory): Some code style optimizations --- .../controllers/memory_storage_controller.py | 204 +++++++++--------- api/app/core/agent/langchain_agent.py | 6 +- api/app/core/memory/models/graph_models.py | 74 ++++--- .../embedding_generation.py | 3 +- api/app/repositories/neo4j/cypher_queries.py | 21 +- api/app/schemas/memory_storage_schema.py | 83 +++---- api/app/services/memory_storage_service.py | 158 +++++++------- api/app/tasks.py | 39 ++-- 8 files changed, 296 insertions(+), 292 deletions(-) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index d91dfc36..d8b39325 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -54,8 +54,8 @@ router = APIRouter( @router.get("/info", response_model=ApiResponse) async def get_storage_info( - storage_id: str, - current_user: User = Depends(get_current_user) + storage_id: str, + current_user: User = Depends(get_current_user) ): """ Example wrapper endpoint - retrieves storage information @@ -75,24 +75,19 @@ async def get_storage_info( return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) - - - - - -@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 +@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 def create_config( - payload: ConfigParamsCreate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), + payload: ConfigParamsCreate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), ) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}") try: # 将 workspace_id 注入到 payload 中(保持为 UUID 类型) @@ -107,9 +102,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {err_str}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) @@ -119,9 +116,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) @@ -129,10 +128,10 @@ def create_config( @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) def delete_config( - config_id: UUID|int, - force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """删除记忆配置(带终端用户保护) @@ -145,24 +144,24 @@ def delete_config( force: 设置为 true 可强制删除(即使有终端用户正在使用) """ workspace_id = current_user.current_workspace_id - config_id=resolve_config_id(config_id, db) + config_id = resolve_config_id(config_id, db) # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info( f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: " f"config_id={config_id}, force={force}" ) - + try: # 使用带保护的删除服务 from app.services.memory_config_service import MemoryConfigService - + config_service = MemoryConfigService(db) result = config_service.delete_config(config_id=config_id, force=force) - + if result["status"] == "error": api_logger.warning( f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}" @@ -172,7 +171,7 @@ def delete_config( msg=result["message"], data={"config_id": str(config_id), "is_default": result.get("is_default", False)} ) - + if result["status"] == "warning": api_logger.warning( f"记忆配置正在使用,无法删除: config_id={config_id}, " @@ -186,7 +185,7 @@ def delete_config( "force_required": result["force_required"] } ) - + api_logger.info( f"记忆配置删除成功: config_id={config_id}, " f"affected_users={result['affected_users']}" @@ -195,7 +194,7 @@ def delete_config( msg=result["message"], data={"affected_users": result["affected_users"]} ) - + except Exception as e: api_logger.error(f"Delete config failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) @@ -203,9 +202,9 @@ def delete_config( @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc def update_config( - payload: ConfigUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -213,12 +212,13 @@ def update_config( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + # 校验至少有一个字段需要更新 if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") - return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") - + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", + "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -231,9 +231,9 @@ def update_config( @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 def update_config_extracted( - payload: ConfigUpdateExtracted, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdateExtracted, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -241,7 +241,7 @@ def update_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -256,11 +256,11 @@ def update_config_extracted( # 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config -@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 +@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 def read_config_extracted( - config_id: UUID | int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id config_id = resolve_config_id(config_id, db) @@ -268,7 +268,7 @@ def read_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}") try: svc = DataConfigService(db) @@ -278,18 +278,19 @@ def read_config_extracted( api_logger.error(f"Read config extracted failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) -@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 + +@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 def read_all_config( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置") try: svc = DataConfigService(db) @@ -303,14 +304,14 @@ def read_all_config( @router.post("/pilot_run", response_model=None) async def pilot_run( - payload: ConfigPilotRun, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigPilotRun, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> StreamingResponse: # 使用集中化的语言校验 language = get_language_from_header(language_type) - + api_logger.info( f"Pilot run requested: config_id={payload.config_id}, " f"dialogue_text_length={len(payload.dialogue_text)}, " @@ -333,9 +334,9 @@ async def pilot_run( @router.get("/search/kb_type_distribution", response_model=ApiResponse) async def get_kb_type_distribution( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") try: result = await kb_type_distribution(end_user_id) @@ -344,12 +345,12 @@ async def get_kb_type_distribution( api_logger.error(f"KB type distribution failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e)) - + @router.get("/search/dialogue", response_model=ApiResponse) async def search_dialogues_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") try: result = await search_dialogue(end_user_id) @@ -361,9 +362,9 @@ async def search_dialogues_num( @router.get("/search/chunk", response_model=ApiResponse) async def search_chunks_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") try: result = await search_chunk(end_user_id) @@ -375,9 +376,9 @@ async def search_chunks_num( @router.get("/search/statement", response_model=ApiResponse) async def search_statements_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") try: result = await search_statement(end_user_id) @@ -389,9 +390,9 @@ async def search_statements_num( @router.get("/search/entity", response_model=ApiResponse) async def search_entities_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") try: result = await search_entity(end_user_id) @@ -403,9 +404,9 @@ async def search_entities_num( @router.get("/search", response_model=ApiResponse) async def search_all_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search all requested for end_user_id: {end_user_id}") try: result = await search_all(end_user_id) @@ -417,9 +418,9 @@ async def search_all_num( @router.get("/search/detials", response_model=ApiResponse) async def search_entities_detials( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search details requested for end_user_id: {end_user_id}") try: result = await search_detials(end_user_id) @@ -431,9 +432,9 @@ async def search_entities_detials( @router.get("/search/edges", response_model=ApiResponse) async def search_entity_edges( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") try: result = await search_edges(end_user_id) @@ -443,14 +444,12 @@ async def search_entity_edges( return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) - - @router.get("/analytics/hot_memory_tags", response_model=ApiResponse) async def get_hot_memory_tags_api( - limit: int = 10, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), - ) -> dict: + limit: int = 10, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +) -> dict: """ 获取热门记忆标签(带Redis缓存) @@ -461,18 +460,18 @@ async def get_hot_memory_tags_api( - 缓存未命中:~600-800ms(取决于LLM速度) """ workspace_id = current_user.current_workspace_id - + # 构建缓存键 cache_key = f"hot_memory_tags:{workspace_id}:{limit}" - + api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}") - + try: # 尝试从Redis缓存获取 import json from app.aioRedis import aio_redis_get, aio_redis_set - + cached_result = await aio_redis_get(cache_key) if cached_result: api_logger.info(f"Cache hit for key: {cache_key}") @@ -481,11 +480,11 @@ async def get_hot_memory_tags_api( return success(data=data, msg="查询成功(缓存)") except json.JSONDecodeError: api_logger.warning(f"Failed to parse cached data, will refresh") - + # 缓存未命中,执行查询 api_logger.info(f"Cache miss for key: {cache_key}, executing query") result = await analytics_hot_memory_tags(db, current_user, limit) - + # 写入缓存(过期时间:5分钟) # 注意:result是列表,需要转换为JSON字符串 try: @@ -495,9 +494,9 @@ async def get_hot_memory_tags_api( except Exception as cache_error: # 缓存写入失败不影响主流程 api_logger.warning(f"Failed to cache result: {str(cache_error)}") - + return success(data=result, msg="查询成功") - + except Exception as e: api_logger.error(f"Hot memory tags failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) @@ -505,8 +504,8 @@ async def get_hot_memory_tags_api( @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) async def clear_hot_memory_tags_cache( - current_user: User = Depends(get_current_user), - ) -> dict: + current_user: User = Depends(get_current_user), +) -> dict: """ 清除热门标签缓存 @@ -516,12 +515,12 @@ async def clear_hot_memory_tags_cache( - 数据更新后立即生效 """ workspace_id = current_user.current_workspace_id - + api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}") - + try: from app.aioRedis import aio_redis_delete - + # 清除所有limit的缓存(常见的limit值) cleared_count = 0 for limit in [5, 10, 15, 20, 30, 50]: @@ -530,12 +529,12 @@ async def clear_hot_memory_tags_cache( if result: cleared_count += 1 api_logger.info(f"Cleared cache for key: {cache_key}") - + return success( - data={"cleared_count": cleared_count}, + data={"cleared_count": cleared_count}, msg=f"成功清除 {cleared_count} 个缓存" ) - + except Exception as e: api_logger.error(f"Clear cache failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e)) @@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache( @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( - current_user: User = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> dict: workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") @@ -553,4 +552,3 @@ async def get_recent_activity_stats_api( except Exception as e: api_logger.error(f"Recent activity stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 88b6371c..464a668a 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -598,8 +598,10 @@ class LangChainAgent: for msg in reversed(output_messages): if isinstance(msg, AIMessage): response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get("total_tokens", - 0) if response_meta else 0 + total_tokens = response_meta.get("token_usage", {}).get( + "total_tokens", + 0 + ) if response_meta else 0 yield total_tokens break if memory_flag: diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 1880b9ab..5a2d8c2e 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -44,21 +44,21 @@ def parse_historical_datetime(v): """ if v is None: return v - + # 处理 Neo4j DateTime 对象 if hasattr(v, 'to_native'): return v.to_native() - + # 处理 Python datetime 对象 if isinstance(v, datetime): return v - + if isinstance(v, str): # 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM] # 支持1-4位年份 pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?' match = re.match(pattern, v) - + if match: try: year = int(match.group(1)) @@ -68,31 +68,31 @@ def parse_historical_datetime(v): minute = int(match.group(5)) if match.group(5) else 0 second = int(match.group(6)) if match.group(6) else 0 microsecond = 0 - + # 处理微秒 if match.group(7): # 补齐或截断到6位 us_str = match.group(7).ljust(6, '0')[:6] microsecond = int(us_str) - + # 处理时区 tzinfo = None if 'Z' in v or match.group(8): tzinfo = timezone.utc - + # 创建 datetime 对象 return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo) - + except (ValueError, OverflowError): # 日期值无效(如月份13、日期32等) return None - + # 如果不匹配模式,尝试使用 fromisoformat(用于标准格式) try: return datetime.fromisoformat(v.replace('Z', '+00:00')) except Exception: return None - + return v @@ -167,7 +167,7 @@ class EntityEntityEdge(Edge): source_statement_id: str = Field(..., description="Statement where this relationship was extracted") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): @@ -206,7 +206,8 @@ class DialogueNode(Node): ref_id: str = Field(..., description="Reference identifier of the dialog") content: str = Field(..., description="Dialogue content") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this dialogue (integer or string)") class StatementNode(Node): @@ -241,17 +242,17 @@ class StatementNode(Node): chunk_id: str = Field(..., description="ID of the parent chunk") stmt_type: str = Field(..., description="Type of the statement") statement: str = Field(..., description="The statement text content") - + # Speaker identification speaker: Optional[str] = Field( None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" ) - + # Emotion fields (ordered as requested, emotion_intensity first for display) emotion_intensity: Optional[float] = Field( - None, - ge=0.0, + None, + ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0 (displayed on node)" ) @@ -264,25 +265,26 @@ class StatementNode(Node): description="Emotion subject: self/other/object" ) emotion_type: Optional[str] = Field( - None, + None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral" ) emotion_keywords: Optional[List[str]] = Field( default_factory=list, description="Emotion keywords list, max 3 items" ) - + # Temporal fields temporal_info: TemporalInfo = Field(..., description="Temporal information") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + # Embedding and other fields statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this statement (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -309,13 +311,13 @@ class StatementNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): """使用通用的历史日期解析函数""" return parse_historical_datetime(v) - + @field_validator('emotion_type', mode='before') @classmethod def validate_emotion_type(cls, v): @@ -326,7 +328,7 @@ class StatementNode(Node): if v not in valid_types: raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") return v - + @field_validator('emotion_subject', mode='before') @classmethod def validate_emotion_subject(cls, v): @@ -337,7 +339,7 @@ class StatementNode(Node): if v not in valid_subjects: raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") return v - + @field_validator('emotion_keywords', mode='before') @classmethod def validate_emotion_keywords(cls, v): @@ -405,19 +407,20 @@ class ExtractedEntityNode(Node): entity_type: str = Field(..., description="Type of the entity") description: str = Field(..., description="Entity description") example: str = Field( - default="", + default="", description="A concise example (around 20 characters) to help understand the entity" ) aliases: List[str] = Field( - default_factory=list, + default_factory=list, description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this entity (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -444,16 +447,16 @@ class ExtractedEntityNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + # Explicit Memory Classification is_explicit_memory: bool = Field( default=False, description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" ) - + @field_validator('aliases', mode='before') @classmethod - def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 + def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 """Validate and clean aliases field using utility function. This validator ensures that the aliases field is always a valid list of strings. @@ -507,8 +510,9 @@ class MemorySummaryNode(Node): memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this summary (integer or string)") + # ACT-R Forgetting Engine Properties original_statement_id: Optional[str] = Field( None, @@ -522,7 +526,7 @@ class MemorySummaryNode(Node): None, description="Timestamp when the nodes were merged" ) - + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py index 72f3641e..a0bccc25 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py @@ -227,7 +227,8 @@ class EmbeddingGenerator: # 打印前几个嵌入向量的维度 for i in range(min(5, len(embeddings))): - print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") + print(f"实体 '{entity_texts[i]}' " + f"嵌入向量维度: {len(embeddings[i])}") # 将嵌入向量赋值给实体 for ent, emb in zip(entity_refs, embeddings): diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 5b2e5f1e..0ac7dcb1 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id, RETURN elementId(r) AS uuid """ - # Entity Merge Query MERGE_ENTITIES = """ MATCH (canonical:ExtractedEntity {id: $canonical_id}) @@ -829,9 +828,8 @@ neo4j_query_all = """ other as entity2 """ - '''针对当前节点下扩长的句子,实体和总结''' -Memory_Timeline_ExtractedEntity=""" +Memory_Timeline_ExtractedEntity = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) = $id AND (ms:ExtractedEntity OR ms:MemorySummary) @@ -869,7 +867,7 @@ RETURN """ -Memory_Timeline_MemorySummary=""" +Memory_Timeline_MemorySummary = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) =$id AND (ms:MemorySummary OR ms:ExtractedEntity) @@ -904,7 +902,7 @@ RETURN } ) AS statement; """ -Memory_Timeline_Statement=""" +Memory_Timeline_Statement = """ MATCH (n) WHERE elementId(n) = $id @@ -947,7 +945,7 @@ RETURN """ '''针对当前节点,主要获取更加完整的句子节点''' -Memory_Space_Emotion_Statement=""" +Memory_Space_Emotion_Statement = """ MATCH (n) WHERE elementId(n) = $id RETURN @@ -957,7 +955,7 @@ RETURN n.statement AS statement; """ -Memory_Space_Emotion_MemorySummary=""" +Memory_Space_Emotion_MemorySummary = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -970,7 +968,7 @@ RETURN DISTINCT e.emotion_type AS emotion_type, e.statement AS statement; """ -Memory_Space_Emotion_ExtractedEntity=""" +Memory_Space_Emotion_ExtractedEntity = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -985,18 +983,18 @@ RETURN DISTINCT '''获取实体''' -Memory_Space_User=""" +Memory_Space_User = """ MATCH (n)-[r]->(m) WHERE n.end_user_id = $end_user_id AND m.name="用户" return DISTINCT elementId(m) as id """ -Memory_Space_Entity=""" +Memory_Space_Entity = """ MATCH (n)-[]-(m) WHERE elementId(m) = $id AND m.entity_type = "Person" RETURN DISTINCT m.name as name,m.end_user_id as end_user_id """ -Memory_Space_Associative=""" +Memory_Space_Associative = """ MATCH (u)-[]-(x)-[]-(h) WHERE elementId(u) = $user_id AND elementId(h) = $id @@ -1060,7 +1058,6 @@ Graph_Node_query = """ """ - # ============================================================ # Community 节点 & BELONGS_TO_COMMUNITY 边 # ============================================================ diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 046b79e7..c8abbc46 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -8,9 +8,6 @@ import uuid from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator - - - # ============================================================================ # 从 json_schema.py 迁移的 Schema # ============================================================================ @@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel): class ConflictResultSchema(BaseModel): """Schema for the conflict result data in the reflexion_data.json file.""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") + data: List[BaseDataSchema] = Field(..., + description="The conflict memory data. Only contains conflicting records when conflict is True.") conflict: bool = Field(..., description="Whether the memory is in conflict.") - quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") - memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") + quality_assessment: Optional[QualityAssessmentSchema] = Field(None, + description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") + memory_verify: Optional[MemoryVerifySchema] = Field(None, + description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") @model_validator(mode="before") def _normalize_data(cls, v): @@ -101,16 +101,19 @@ class ChangeRecordSchema(BaseModel): - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 """ field: List[Dict[str, Any]] = Field( - ..., + ..., description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" ) + class ResolvedSchema(BaseModel): """Schema for the resolved memory data in the reflexion_data""" original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") - resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") - change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") + resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, + description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") + change: Optional[List[ChangeRecordSchema]] = Field(None, + description="List of detailed change records with IDs and field information.") class SingleReflexionResultSchema(BaseModel): @@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel): resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") type: str = Field("reflexion_result", description="The type identifier.") + class ReflexionResultSchema(BaseModel): """Schema for the complete reflexion result data - a list of individual conflict resolutions.""" - results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") + results: List[SingleReflexionResultSchema] = Field(..., + description="List of individual conflict resolution results, grouped by conflict type.") @model_validator(mode="before") def _normalize_resolved(cls, v): @@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel): # Composite key identifying a config row class ConfigKey(BaseModel): # 配置参数键模型 model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") - user_id: str = Field("user_id", description="用户标识(字符串)") - apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") + user_id: str | None = Field(default=None, description="用户标识(字符串)") + apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)") # Allowed chunking strategies (extendable later) @@ -228,23 +233,25 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, config_name: str = Field("配置名称", description="配置名称(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)") workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)") - + # 本体场景关联(可选) scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表") - + # 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入) pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充") - + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") + + class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) model_config = ConfigDict(populate_by_name=True, extra="forbid") # config_name: str = Field("配置名称", description="配置名称(字符串)") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 @@ -255,7 +262,7 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") @@ -322,14 +329,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 # 遗忘引擎配置参数更新模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0") class ConfigPilotRun(BaseModel): # 试运行触发请求模型 - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行") model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -364,11 +371,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) def fail( - msg: str, - error_code: str = "ERROR", - data: Optional[Any] = None, - time: Optional[int] = None, - query_preview: Optional[str] = None, + msg: str, + error_code: str = "ERROR", + data: Optional[Any] = None, + time: Optional[int] = None, + query_preview: Optional[str] = None, ) -> ApiResponse: payload = data if query_preview is not None: @@ -387,12 +394,13 @@ def fail( time=time or _now_ms(), ) + class GenerateCacheRequest(BaseModel): """缓存生成请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: Optional[str] = Field( - None, + None, description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成" ) @@ -404,7 +412,7 @@ class GenerateCacheRequest(BaseModel): class ForgettingTriggerRequest(BaseModel): """手动触发遗忘周期请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: str = Field(..., description="组ID(即终端用户ID,必填)") max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)") min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)") @@ -413,7 +421,7 @@ class ForgettingTriggerRequest(BaseModel): class ForgettingConfigResponse(BaseModel): """遗忘引擎配置响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") decay_constant: float = Field(..., description="衰减常数 d") lambda_time: float = Field(..., description="时间衰减参数") @@ -432,7 +440,7 @@ class ForgettingConfigUpdateRequest(BaseModel): """遗忘引擎配置更新请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") @@ -448,7 +456,7 @@ class ForgettingConfigUpdateRequest(BaseModel): class ForgettingCycleHistoryPoint(BaseModel): """遗忘周期历史数据点模型(用于趋势图)""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + date: str = Field(..., description="日期(格式: '1/1', '1/2')") merged_count: int = Field(..., description="每日融合节点数") average_activation: Optional[float] = Field(None, description="平均激活值") @@ -459,7 +467,7 @@ class ForgettingCycleHistoryPoint(BaseModel): class PendingForgettingNode(BaseModel): """待遗忘节点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + node_id: str = Field(..., description="节点ID") node_type: str = Field(..., description="节点类型:statement/entity/summary") content_summary: str = Field(..., description="内容摘要") @@ -472,7 +480,8 @@ class ForgettingStatsResponse(BaseModel): model_config = ConfigDict(populate_by_name=True, extra="forbid") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") node_distribution: Dict[str, int] = Field(..., description="节点类型分布") - recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") + recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., + description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") timestamp: int = Field(..., description="统计时间(时间戳)") @@ -480,7 +489,7 @@ class ForgettingStatsResponse(BaseModel): class ForgettingReportResponse(BaseModel): """遗忘周期报告响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + merged_count: int = Field(..., description="融合的节点对数量") nodes_before: int = Field(..., description="遗忘前的节点总数") nodes_after: int = Field(..., description="遗忘后的节点总数") @@ -495,7 +504,7 @@ class ForgettingReportResponse(BaseModel): class ForgettingCurvePoint(BaseModel): """遗忘曲线数据点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + day: int = Field(..., description="天数") activation: float = Field(..., description="激活值") retention_rate: float = Field(..., description="保持率(与激活值相同)") @@ -504,7 +513,7 @@ class ForgettingCurvePoint(BaseModel): class ForgettingCurveRequest(BaseModel): """遗忘曲线请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)") days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") @@ -513,6 +522,6 @@ class ForgettingCurveRequest(BaseModel): class ForgettingCurveResponse(BaseModel): """遗忘曲线响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表") config: Dict[str, Any] = Field(..., description="使用的配置参数") diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 6e7c1ad4..264ae4df 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -11,9 +11,11 @@ import time from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional +from dotenv import load_dotenv +from sqlalchemy.orm import Session + from app.core.logging_config import get_config_logger, get_logger from app.core.memory.analytics.hot_memory_tags import ( - get_hot_memory_tags, get_raw_tags_from_db, filter_tags_with_llm, ) @@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import ( ) from app.services.memory_config_service import MemoryConfigService from app.utils.sse_utils import format_sse_message -from dotenv import load_dotenv -from sqlalchemy.orm import Session logger = get_logger(__name__) config_logger = get_config_logger() @@ -45,10 +45,10 @@ _neo4j_connector = Neo4jConnector() class MemoryStorageService: """Service for memory storage operations""" - + def __init__(self): logger.info("MemoryStorageService initialized") - + async def get_storage_info(self) -> dict: """ Example wrapper method - retrieves storage information @@ -59,17 +59,17 @@ class MemoryStorageService: Storage information dictionary """ logger.info("Getting storage info ") - + # Empty wrapper - implement your logic here result = { "status": "active", "message": "This is an example wrapper" } - - return result - -class DataConfigService: # 数据配置服务类(PostgreSQL) + return result + + +class DataConfigService: # 数据配置服务类(PostgreSQL) """Service layer for config params CRUD. 使用 SQLAlchemy ORM 进行数据库操作。 @@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return data_list # --- Create --- - def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) + def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) # 业务层检查同一工作空间下是否已存在同名配置 if params.workspace_id and params.config_name: from app.models.memory_config_model import MemoryConfig @@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return None # --- Delete --- - def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) + def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) success = MemoryConfigRepository.delete(self.db, key.config_id) if not success: raise ValueError("未找到配置") return {"affected": 1} # --- Update --- - def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 + def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 config = MemoryConfigRepository.update(self.db, update) if not config: raise ValueError("未找到配置") return {"affected": 1} - def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 + def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 config = MemoryConfigRepository.update_extracted(self.db, update) if not config: raise ValueError("未找到配置") @@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config() # --- Read --- - def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 + def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id) if not result: raise ValueError("未找到配置") return result # --- Read All --- - def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 + def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数 results = MemoryConfigRepository.get_all(self.db, workspace_id) # 检查并修正 pruning_scene 与 scene_name 不一致的记录 @@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) except (ValueError, TypeError): config_id_old = None - if config_id_old: - memory_config=config_id_old + memory_config = config_id_old else: - memory_config=config.config_id + memory_config = config.config_id config_dict = { "config_id": memory_config, "config_name": config.config_name, @@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 return self._convert_timestamps_to_format(data_list) - async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]: """ 流式执行试运行,产生 SSE 格式的进度事件 @@ -311,14 +309,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) """ from pathlib import Path project_root = str(Path(__file__).resolve().parents[2]) - + try: # 发出初始进度事件 yield format_sse_message("starting", { "message": "开始试运行...", "time": int(time.time() * 1000) }) - + # 步骤 1: 配置加载和验证(数据库优先) payload_cid = str(getattr(payload, "config_id", "") or "").strip() cid: Optional[str] = payload_cid if payload_cid else None @@ -344,27 +342,28 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 关联了本体场景,优先使用 custom_text if hasattr(payload, 'custom_text') and payload.custom_text: dialogue_text = payload.custom_text.strip() - logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") + logger.info( + f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") else: # 如果没有提供 custom_text,回退到 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" - logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") + logger.info( + f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") else: # 没有关联本体场景,使用 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}") - + # 验证最终使用的文本不为空 if not dialogue_text: raise ValueError("试运行模式必须提供有效的文本内容(dialogue_text 或 custom_text)") - - logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") + logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") # 步骤 2: 创建进度回调函数捕获管线进度 # 使用队列在回调和生成器之间传递进度事件 progress_queue: asyncio.Queue = asyncio.Queue() - + async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None: """ 进度回调函数,将进度事件放入队列 @@ -375,14 +374,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) data: 可选的结果数据(用于传递节点执行结果) """ await progress_queue.put((stage, message, data)) - + # 步骤 3: 在后台任务中执行管线 async def run_pipeline(): """在后台执行管线并捕获异常""" try: from app.services.pilot_run_service import run_pilot_extraction - - logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") + + logger.info( + f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") await run_pilot_extraction( memory_config=memory_config, dialogue_text=dialogue_text, @@ -391,60 +391,60 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) language=language, ) logger.info("[PILOT_RUN_STREAM] pipeline_main completed") - + # 标记管线完成 await progress_queue.put(("__PIPELINE_COMPLETE__", "", None)) except Exception as e: # 将异常放入队列 await progress_queue.put(("__PIPELINE_ERROR__", str(e), None)) - + # 启动后台任务 pipeline_task = asyncio.create_task(run_pipeline()) - + # 步骤 4: 从队列中读取进度事件并发出 while True: try: # 等待进度事件,设置超时以检测客户端断开 stage, message, data = await asyncio.wait_for( - progress_queue.get(), + progress_queue.get(), timeout=0.5 ) - + # 检查特殊标记 if stage == "__PIPELINE_COMPLETE__": break elif stage == "__PIPELINE_ERROR__": raise RuntimeError(message) - + # 构建进度事件数据 progress_data = { "message": message, "time": int(time.time() * 1000) } - + # 如果有结果数据,添加到事件中 if data: progress_data["data"] = data - + # 发出进度事件,使用 stage 作为事件类型 yield format_sse_message(stage, progress_data) - + except TimeoutError: # 超时,继续等待(这允许检测客户端断开) continue - + # 等待管线任务完成 await pipeline_task - + # 步骤 5: 读取提取结果 from app.core.config import settings result_path = settings.get_memory_output_path("extracted_result.json") if not os.path.isfile(result_path): raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") - + with open(result_path, "r", encoding="utf-8") as rf: extracted_result = json.load(rf) - + # 步骤 6: 计算本体覆盖率并合并到结果中 result_data = { "config_id": cid, @@ -460,15 +460,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) result_data["ontology_coverage"] = ontology_coverage except Exception as cov_err: logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True) - + yield format_sse_message("result", result_data) - + # 步骤 7: 发出完成事件 yield format_sse_message("done", { "message": "试运行完成", "time": int(time.time() * 1000) }) - + except asyncio.CancelledError: # 客户端断开连接 logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming") @@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "time": int(time.time() * 1000) }) - async def _compute_ontology_coverage( - self, - extracted_result: Dict[str, Any], - memory_config, + self, + extracted_result: Dict[str, Any], + memory_config, ) -> Optional[Dict[str, Any]]: """根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。 @@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # Ensure env for connector (e.g., NEO4J_PASSWORD) -load_dotenv() -_neo4j_connector = Neo4jConnector() async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: @@ -664,7 +661,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A # 检查结果是否为空或长度不足 if not result or len(result) < 4: data = { - "total": 0, + "total": 0, "distribution": [ {"type": "dialogue", "count": 0}, {"type": "chunk", "count": 0}, @@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] ) return result + async def analytics_hot_memory_tags( - db: Session, - current_user: User, - limit: int = 10 + db: Session, + current_user: User, + limit: int = 10 ) -> List[Dict[str, Any]]: """ 获取热门记忆标签,按数量排序并返回前N个 @@ -721,27 +719,27 @@ async def analytics_hot_memory_tags( from app.services.memory_dashboard_service import get_workspace_end_users # 使用 asyncio.to_thread 避免阻塞事件循环 end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user) - + if not end_users: return [] - + # 步骤1: 收集所有用户的原始标签(不调用LLM) connector = Neo4jConnector() try: all_raw_tags = [] for end_user in end_users: raw_tags = await get_raw_tags_from_db( - connector, - str(end_user.id), - limit=raw_limit, + connector, + str(end_user.id), + limit=raw_limit, by_user=False ) if raw_tags: all_raw_tags.extend(raw_tags) - + if not all_raw_tags: return [] - + # 步骤2: 聚合相同标签的频率 tag_frequency_map = {} for tag_name, frequency in all_raw_tags: @@ -749,36 +747,36 @@ async def analytics_hot_memory_tags( tag_frequency_map[tag_name] += frequency else: tag_frequency_map[tag_name] = frequency - + # 步骤3: 按频率降序排序,取前raw_limit个 sorted_tags = sorted( - tag_frequency_map.items(), - key=lambda x: x[1], + tag_frequency_map.items(), + key=lambda x: x[1], reverse=True )[:raw_limit] - + if not sorted_tags: return [] - + # 步骤4: 只调用一次LLM进行筛选 tag_names = [tag for tag, _ in sorted_tags] - + # 使用第一个用户的end_user_id来获取LLM配置 # 因为同一工作空间下的用户应该使用相同的配置 first_end_user_id = str(end_users[0].id) filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id) - + # 步骤5: 根据LLM筛选结果构建最终列表(保留频率) final_tags = [] for tag, freq in sorted_tags: if tag in filtered_tag_names: final_tags.append((tag, freq)) - + # 步骤6: 只返回前limit个 top_tags = final_tags[:limit] - + return [{"name": t, "frequency": f} for t, f in top_tags] - + finally: await connector.close() @@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> source = "log" total = ( - stats.get("chunk_count", 0) - + stats.get("statements_count", 0) - + stats.get("triplet_entities_count", 0) - + stats.get("triplet_relations_count", 0) - + stats.get("temporal_count", 0) + stats.get("chunk_count", 0) + + stats.get("statements_count", 0) + + stats.get("triplet_entities_count", 0) + + stats.get("triplet_relations_count", 0) + + stats.get("temporal_count", 0) ) # 计算"最新一次活动多久前"(仅日志来源时有效) @@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} return data - - diff --git a/api/app/tasks.py b/api/app/tasks.py index f5258330..c37e564e 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, - user_rag_memory_id: str, - language: str = "zh") -> Dict[str, Any]: +def write_message_task( + self, + end_user_id: str, + message: list[dict], + config_id: str | int, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" +) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) @@ -1105,14 +1111,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s try: with get_db_context() as db: actual_config_id = resolve_config_id(config_id, db) - print(100 * '-') - print(actual_config_id) - print(100 * '-') - logger.info( - f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") + logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} " + f"(type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: - logger.error( - f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") + logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} " + f"(type: {type(config_id).__name__}), error: {e}") return { "status": "FAILURE", "error": f"Invalid config_id format: {config_id} - {str(e)}", @@ -1151,8 +1154,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time - logger.info( - f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + logger.info(f"[CELERY WRITE] Task completed successfully " + f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: @@ -1167,7 +1170,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s ) except Exception as _e: logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") - return { "status": "SUCCESS", "result": result, @@ -2672,7 +2674,7 @@ def write_perceptual_memory( ignore_result=False, max_retries=0, acks_late=False, - time_limit=7200, # 2小时硬超时 + time_limit=7200, # 2小时硬超时 soft_time_limit=6900, ) def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: @@ -2749,7 +2751,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s llm_model_id=llm_model_id, ) - logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") + logger.info( + f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") await engine.full_clustering(end_user_id) initialized += 1 logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") @@ -2772,12 +2775,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time From dce7206c447893bf89d7d3efcf4046c3139b22d6 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:28:21 +0800 Subject: [PATCH 2/8] fix(celery, rag): unify rag_write return type and remove deprecated downstream calls - Unify the return type of `rag_write` in Celery tasks for consistency. - Remove two deprecated downstream API calls to avoid obsolete dependencies. --- .../controllers/memory_agent_controller.py | 272 +++++++++--------- .../repositories/memory_config_repository.py | 48 +--- api/app/services/memory_agent_service.py | 20 +- api/app/services/memory_konwledges_server.py | 19 +- 4 files changed, 169 insertions(+), 190 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index e3d2bf92..aa4d48e3 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -118,142 +118,142 @@ async def download_log( return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) -@router.post("/writer_service", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Write service endpoint - processes write operations synchronously - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Response with write operation status - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - - # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id - if storage_type == 'rag': - if workspace_id: - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: - user_rag_memory_id = str(knowledge.id) - else: - api_logger.warning( - f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") - storage_type = 'neo4j' - else: - api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") - storage_type = 'neo4j' - - api_logger.info( - f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") - try: - messages_list = memory_agent_service.get_messages_list(user_input) - result = await memory_agent_service.write_memory( - user_input.end_user_id, - messages_list, - config_id, - db, - storage_type, - user_rag_memory_id, - language - ) - - return success(data=result, msg="写入成功") - except BaseException as e: - # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup - if hasattr(e, 'exceptions'): - error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] - detailed_error = "; ".join(error_messages) - api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) - api_logger.error(f"Write operation error: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) - - -@router.post("/writer_service_async", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server_async( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Async write service endpoint - enqueues write processing to Celery - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Task ID for tracking async operation - Use GET /memory/write_result/{task_id} to check task status and get result - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info( - f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - try: - # 获取标准化的消息列表 - messages_list = memory_agent_service.get_messages_list(user_input) - - task = celery_app.send_task( - "app.core.memory.agent.write_message", - args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] - ) - api_logger.info(f"Write task queued: {task.id}") - - return success(data={"task_id": task.id}, msg="写入任务已提交") - except Exception as e: - api_logger.error(f"Async write operation failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# @router.post("/writer_service", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Write service endpoint - processes write operations synchronously +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Response with write operation status +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# +# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id +# if storage_type == 'rag': +# if workspace_id: +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: +# user_rag_memory_id = str(knowledge.id) +# else: +# api_logger.warning( +# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") +# storage_type = 'neo4j' +# else: +# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") +# storage_type = 'neo4j' +# +# api_logger.info( +# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") +# try: +# messages_list = memory_agent_service.get_messages_list(user_input) +# result = await memory_agent_service.write_memory( +# user_input.end_user_id, +# messages_list, +# config_id, +# db, +# storage_type, +# user_rag_memory_id, +# language +# ) +# +# return success(data=result, msg="写入成功") +# except BaseException as e: +# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup +# if hasattr(e, 'exceptions'): +# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] +# detailed_error = "; ".join(error_messages) +# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) +# api_logger.error(f"Write operation error: {str(e)}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# +# +# @router.post("/writer_service_async", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server_async( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Async write service endpoint - enqueues write processing to Celery +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Task ID for tracking async operation +# Use GET /memory/write_result/{task_id} to check task status and get result +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info( +# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# if workspace_id: +# +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: user_rag_memory_id = str(knowledge.id) +# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") +# try: +# # 获取标准化的消息列表 +# messages_list = memory_agent_service.get_messages_list(user_input) +# +# task = celery_app.send_task( +# "app.core.memory.agent.write_message", +# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] +# ) +# api_logger.info(f"Write task queued: {task.id}") +# +# return success(data={"task_id": task.id}, msg="写入任务已提交") +# except Exception as e: +# api_logger.error(f"Async write operation failed: {str(e)}") +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) @router.post("/read_service", response_model=ApiResponse) diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 22f13449..5c2f81a7 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -309,57 +309,21 @@ class MemoryConfigRepository: Returns: Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None - - Raises: - ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新萃取配置: config_id={update.config_id}") try: - db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() + stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id) + db_config = db.execute(stmt).scalar_one_or_none() if not db_config: db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None - # 更新字段映射 - field_mapping = { - # 模型选择 - "llm_id": "llm_id", - "embedding_id": "embedding_id", - "rerank_id": "rerank_id", - # 记忆萃取引擎 - "enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise", - "enable_llm_disambiguation": "enable_llm_disambiguation", - "deep_retrieval": "deep_retrieval", - "t_type_strict": "t_type_strict", - "t_name_strict": "t_name_strict", - "t_overall": "t_overall", - "state": "state", - "chunker_strategy": "chunker_strategy", - # 句子提取 - "statement_granularity": "statement_granularity", - "include_dialogue_context": "include_dialogue_context", - "max_context": "max_context", - # 剪枝配置 - "pruning_enabled": "pruning_enabled", - "pruning_scene": "pruning_scene", - "pruning_threshold": "pruning_threshold", - # 自我反思配置 - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - } + update_data = update.model_dump(exclude_unset=True) + update_data.pop("config_id", None) - has_update = False - for api_field, db_field in field_mapping.items(): - value = getattr(update, api_field, None) - if value is not None: - setattr(db_config, db_field, value) - has_update = True - - if not has_update: - raise ValueError("No fields to update") + for field, value in update_data.items(): + setattr(db_config, field, value) db.commit() db.refresh(db_config) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index af9a04e2..514cb12f 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -267,8 +267,16 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, - db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: + async def write_memory( + self, + end_user_id: str, + messages: list[dict], + config_id: Optional[uuid.UUID] | int, + db: Session, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" + ) -> str: """ Process write operation with config_id @@ -297,8 +305,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError( - f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. " + f"Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -338,8 +346,8 @@ class MemoryAgentService: if storage_type == "rag": # For RAG storage, convert messages to single string message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(end_user_id, message_text, user_rag_memory_id) - return result + await write_rag(end_user_id, message_text, user_rag_memory_id) + return "success" else: async with make_write_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index b8961d33..523adadb 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -341,7 +341,7 @@ async def memory_konwledges_up( ) db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) - return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") + return db_document async def create_document_chunk( @@ -350,7 +350,7 @@ async def create_document_chunk( create_data: ChunkCreate, db: Session, current_user: User -): +) -> DocumentChunk: """ 创建文档块 @@ -439,10 +439,10 @@ async def create_document_chunk( db_document.chunk_num += 1 db.commit() - return success(data=chunk, msg="文档块创建成功") + return chunk -async def write_rag(end_user_id, message, user_rag_memory_id): +async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk: """ 将消息写入 RAG 知识库 @@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id): document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") print('======', document) api_logger.info(f"查找文档结果: document_id={document}") + create_chunks = ChunkCreate(content=message) if document is not None: # 文档已存在,直接添加新块 api_logger.info(f"文档已存在,添加新块: document_id={document}") - create_chunks = ChunkCreate(content=message) result = await create_document_chunk( kb_id=kb_uuid, document_id=uuid.UUID(document), @@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id): else: # 文档不存在,创建新文档 api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}") - result = await memory_konwledges_up( + document = await memory_konwledges_up( kb_id=user_rag_memory_id, parent_id=user_rag_memory_id, create_data=create_data, db=db, current_user=current_user ) + result = await create_document_chunk( + kb_id=kb_uuid, + document_id=document.id, + create_data=create_chunks, + db=db, + current_user=current_user + ) # 重新查询刚创建的文档ID new_document_id = find_document_id_by_kb_and_filename( db=db, From 31085ed678ef29d77ba9a7feaa59338a7201d195 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:31:17 +0800 Subject: [PATCH 3/8] fix(workflow): fix memory write behavior in RAG workspace --- .../core/workflow/engine/runtime_schema.py | 14 ++++- api/app/core/workflow/engine/state_manager.py | 9 ++- api/app/core/workflow/engine/variable_pool.py | 10 ++- api/app/core/workflow/executor.py | 20 ++++-- api/app/core/workflow/nodes/memory/node.py | 61 ++++++++++++++++--- api/app/services/workflow_service.py | 29 ++++++++- api/app/services/workspace_service.py | 2 +- 7 files changed, 128 insertions(+), 17 deletions(-) diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py index e4bf65af..48eafaa9 100644 --- a/api/app/core/workflow/engine/runtime_schema.py +++ b/api/app/core/workflow/engine/runtime_schema.py @@ -12,14 +12,26 @@ class ExecutionContext(BaseModel): execution_id: str workspace_id: str user_id: str + memory_storage_type: str + user_rag_memory_id: str checkpoint_config: RunnableConfig @classmethod - def create(cls, execution_id: str, workspace_id: str, user_id: str): + def create( + cls, + execution_id: str, + workspace_id: str, + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str + ): return cls( execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id, + checkpoint_config=RunnableConfig( configurable={ "thread_id": uuid.uuid4(), diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 0a4a1463..2da0d3a8 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -33,6 +33,8 @@ class WorkflowState(dict): "workspace_id", "user_id", "activate", + "memory_storage_type", + "user_rag_memory_id" }) __optional_keys__ = frozenset({ "error", @@ -62,6 +64,9 @@ class WorkflowState(dict): # node activate status activate: Annotated[dict[str, bool], merge_activate_state] + memory_storage_type: str + user_rag_memory_id: str + class WorkflowStateManager: def create_initial_state( @@ -85,7 +90,9 @@ class WorkflowStateManager: looping=0, activate={ start_node_id: True - } + }, + memory_storage_type=execution_context.memory_storage_type, + user_rag_memory_id=execution_context.user_rag_memory_id ) @staticmethod diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index cf6f4a7b..d4e1b488 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE -from app.core.workflow.variable.variable_objects import T, create_variable_instance +from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable logger = logging.getLogger(__name__) @@ -373,6 +373,14 @@ class VariablePool: def copy(self, pool: 'VariablePool'): self.variables = deepcopy(pool.variables) + def is_file_variable(self, selector): + variable_struct = self._get_variable_struct(selector) + if isinstance(variable_struct, FileVariable): + return True + elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: + return True + return False + def to_dict(self) -> dict[str, Any]: """导出为字典 diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index c9ed6e65..6a127e96 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -409,7 +409,9 @@ async def execute_workflow( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ) -> dict[str, Any]: """ Execute a workflow (convenience function, non-streaming). @@ -420,6 +422,8 @@ async def execute_workflow( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id: rag knowledge db id + memory_storage_type: neo4j / rag Returns: dict: Workflow execution result. @@ -427,7 +431,9 @@ async def execute_workflow( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, @@ -441,7 +447,9 @@ async def execute_workflow_stream( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ): """ Execute a workflow in streaming mode (convenience function). @@ -452,6 +460,8 @@ async def execute_workflow_stream( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id: rag knowledge db id + memory_storage_type: neo4j / rag Yields: dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. @@ -459,7 +469,9 @@ async def execute_workflow_stream( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 1d42e82e..82363056 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -1,3 +1,4 @@ +import re from typing import Any from app.core.workflow.engine.state_manager import WorkflowState @@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.db import get_db_read +from app.schemas import FileInput from app.services.memory_agent_service import MemoryAgentService from app.tasks import write_message_task @@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode): search_switch=self.typed_config.search_switch, history=[], db=db, - storage_type="neo4j", - user_rag_memory_id="" + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) @@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} + @staticmethod + def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]: + variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}' + variable_pattern = re.compile(variable_pattern_string) + variables = variable_pattern.findall(content) + file_variables = [] + for variable in variables: + if variable_pool.is_file_variable(variable): + file_variables.append(variable) + for var in file_variables: + content = content.replace(var, "") + return file_variables, content + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryWriteNodeConfig(**self.config) end_user_id = self.get_variable("sys.user_id", variable_pool) @@ -56,6 +72,7 @@ class MemoryWriteNode(BaseNode): if not end_user_id: raise RuntimeError("End user id is required") messages = [] + multimodal_memories = [] if self.typed_config.message: messages.append({ "role": "user", @@ -63,17 +80,45 @@ class MemoryWriteNode(BaseNode): }) for message in self.typed_config.messages: + file_variables, content = self._extract_multimodal_memory_variables( + message.content, + variable_pool + ) + file_info = [] + for var in file_variables: + instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var) + if isinstance(instence, FileVariable): + file_info.append(FileInput( + type=instence.value.type, + transfer_method=instence.value.transfer_method, + upload_file_id=instence.value.file_id, + url=instence.value.url, + file_type=instence.value.origin_file_type + ).model_dump()) + elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable: + for file_instence in instence.value: + file_info.append(FileInput( + type=file_instence.value.type, + transfer_method=file_instence.value.transfer_method, + upload_file_id=file_instence.value.file_id, + url=file_instence.value.url, + file_type=file_instence.value.origin_file_type + ).model_dump()) + multimodal_memories.append({ + "role": message.role, + "files": file_info + }) messages.append({ "role": message.role, - "content": self._render_template(message.content, variable_pool) + "content": self._render_template(content, variable_pool) }) write_message_task.delay( - end_user_id, - messages, - str(self.typed_config.config_id), - "neo4j", - "" + end_user_id=end_user_id, + message=messages, + config_id=str(self.typed_config.config_id), + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) return "success" diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 56f34496..db659268 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories import knowledge_repository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, @@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService +from app.services.workspace_service import get_workspace_storage_type_without_auth logger = logging.getLogger(__name__) @@ -536,6 +538,25 @@ class WorkflowService: mapped = internal_event return mapped + def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]: + storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id) + user_rag_memory_id = "" + if storage_type == "rag": + knowledge = knowledge_repository.get_knowledge_by_name( + db=self.db, + name="USER_RAG_MERORY", + workspace_id=workspace_id + ) + if knowledge: + user_rag_memory_id = str(knowledge.id) + else: + logger.warning( + f"No knowledge base named 'USER_RAG_MEMORY' found, " + f"workspace_id: {workspace_id}, will use neo4j storage" + ) + storage_type = 'neo4j' + return storage_type, user_rag_memory_id + # ==================== 工作流执行 ==================== async def run( @@ -603,6 +624,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files message_id = uuid.uuid4() # 更新状态为运行中 @@ -627,7 +649,9 @@ class WorkflowService: input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ) # 更新执行结果 if result.get("status") == "completed": @@ -776,6 +800,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) @@ -797,6 +822,8 @@ class WorkflowService: execution_id=execution.execution_id, workspace_id=str(workspace_id), user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): if event.get("event") == "workflow_end": status = event.get("data", {}).get("status") diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index cefb8380..90b5cf65 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -863,7 +863,7 @@ def get_workspace_storage_type( def get_workspace_storage_type_without_auth( db: Session, workspace_id: uuid.UUID, -) -> Optional[str]: +) -> str: """获取工作空间的存储类型(无需权限验证,用于公开分享等场景) Args: From 2ff81ba101b422e600982dcc96fc673bc4684223 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 23 Mar 2026 16:33:25 +0800 Subject: [PATCH 4/8] feat(memory): support perception-aware memory writing in workflow and Neo4j nodes --- .../core/memory/agent/utils/write_tools.py | 65 ++- api/app/core/memory/models/graph_models.py | 18 + .../deduplication/two_stage_dedup.py | 24 +- .../extraction_orchestrator.py | 465 +++++++++--------- .../knowledge_extraction/memory_summary.py | 1 - api/app/core/workflow/engine/variable_pool.py | 4 +- api/app/core/workflow/nodes/base_node.py | 8 +- api/app/core/workflow/nodes/llm/node.py | 14 +- api/app/core/workflow/nodes/memory/node.py | 1 + api/app/models/memory_config_model.py | 3 + .../repositories/memory_config_repository.py | 57 +-- api/app/repositories/neo4j/add_nodes.py | 111 ++++- api/app/repositories/neo4j/cypher_queries.py | 33 ++ api/app/schemas/memory_config_schema.py | 6 + api/app/services/app_chat_service.py | 4 +- api/app/services/draft_run_service.py | 4 +- api/app/services/memory_agent_service.py | 85 ++-- api/app/services/memory_api_service.py | 85 ++-- api/app/services/memory_config_service.py | 194 +++++--- api/app/services/memory_perceptual_service.py | 120 ++++- api/app/services/multimodal_service.py | 31 +- api/app/tasks.py | 6 +- 22 files changed, 820 insertions(+), 519 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b62eb50a..147a0316 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract pipeline. Only MemoryConfig is needed - clients are constructed internally. """ import asyncio +import uuid import time from datetime import datetime @@ -13,28 +14,31 @@ from dotenv import load_dotenv from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \ + memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context +from app.models import MemoryPerceptualModel from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges -from app.repositories.neo4j.add_nodes import add_memory_summary_nodes +from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \ + add_perceptual_dialogue_edges from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig - load_dotenv() logger = get_agent_logger(__name__) async def write( - end_user_id: str, - memory_config: MemoryConfig, - messages: list, - ref_id: str = "wyl20251027", - language: str = "zh", + end_user_id: str, + memory_config: MemoryConfig, + messages: list, + file_content: list[MemoryPerceptualModel], + ref_id: str = "", + language: str = "zh", ) -> None: """ Execute the complete knowledge extraction pipeline. @@ -43,9 +47,12 @@ async def write( end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] - ref_id: Reference ID, defaults to "wyl20251027" + file_content: mutilmodal message list + ref_id: Reference ID, defaults to "" language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ + if not ref_id: + ref_id = uuid.uuid4().hex # Extract config values embedding_model_id = str(memory_config.embedding_model_id) chunker_strategy = memory_config.chunker_strategy @@ -99,14 +106,14 @@ async def write( if memory_config.scene_id: try: from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene - + with get_db_context() as db: ontology_types = load_ontology_types_for_scene( scene_id=memory_config.scene_id, workspace_id=memory_config.workspace_id, db=db ) - + if ontology_types: logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" @@ -173,7 +180,8 @@ async def write( schedule_clustering_after_write( all_entity_nodes, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + embedding_model_id=str( + memory_config.embedding_model_id) if memory_config.embedding_model_id else None, ) break else: @@ -208,9 +216,8 @@ async def write( summaries = await memory_summary_generation( chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language ) - + ms_connector = Neo4jConnector() try: - ms_connector = Neo4jConnector() await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector) finally: @@ -223,6 +230,34 @@ async def write( finally: log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file) + # Step 5: Save perceptual memory to Neo4j + step_start = time.time() + if file_content: + try: + pc_connector = Neo4jConnector() + try: + created_ids = await add_perceptual_nodes( + perceptuals=file_content, + connector=pc_connector, + embedder_client=embedder_client, + ) + # 如果有 ref_id,建立感知记忆与对话的关联 + if ref_id and created_ids: + await add_perceptual_dialogue_edges( + perceptuals=file_content, + dialog_id=ref_id, + connector=pc_connector, + ) + logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j") + finally: + try: + await pc_connector.close() + except Exception: + pass + except Exception as e: + logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True) + log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file) + # Log total pipeline time total_time = time.time() - pipeline_start log_time("TOTAL PIPELINE TIME", total_time, log_file) @@ -251,4 +286,4 @@ async def write( logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file + logger.info(f"Total execution time: {total_time:.2f} seconds") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 5a2d8c2e..fb251f1f 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -553,3 +553,21 @@ class MemorySummaryNode(Node): ge=0, description="Total number of times this node has been accessed (reset to 1 on creation)" ) + + +class MutlimodalNode(Node): + """Node representing a multimodal message in the knowledge graph. + + Attributes: + dialog_id: ID of the parent dialog + message_id: ID of the message + metadata: Additional message metadata + embedding: Optional embedding vector for the message + """ + dialog_id: str = Field(..., description="ID of the parent dialog") + message_id: str = Field(..., description="ID of the message") + summary: str = Field(..., description="The text content of the message") + file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')") + file_path: List[str] = Field(..., description="List of file paths for multimodal content") + metadata: dict = Field(default_factory=dict, description="Additional message metadata") + embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message") diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index f28b8a5f..4b9c5718 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector async def dedup_layers_and_merge_and_return( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - pipeline_config: ExtractionPipelineConfig, - connector: Optional[Neo4jConnector] = None, - llm_client = None, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + pipeline_config: ExtractionPipelineConfig, + connector: Optional[Neo4jConnector] = None, + llm_client=None, ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return( List[StatementChunkEdge], List[StatementEntityEdge], List[EntityEntityEdge], - dict, # 新增:返回去重详情 + dict ]: """ 执行两层实体去重与融合: diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 00d06f72..6e94a84f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -31,11 +31,10 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, StatementChunkEdge, StatementEntityEdge, - StatementNode, + StatementNode ) from app.core.memory.models.message_models import DialogData from app.core.memory.models.ontology_extraction_models import OntologyTypeList -from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, ) @@ -46,7 +45,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb embedding_generation, generate_entity_embeddings_from_triplets, ) - # 导入各个提取模块 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( StatementExtractor, @@ -90,16 +88,16 @@ class ExtractionOrchestrator: """ def __init__( - self, - llm_client: LLMClient, - embedder_client: OpenAIEmbedderClient, - connector: Neo4jConnector, - config: Optional[ExtractionPipelineConfig] = None, - progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, - embedding_id: Optional[str] = None, - ontology_types: Optional[OntologyTypeList] = None, - enable_general_types: bool = True, - language: str = "zh", + self, + llm_client: LLMClient, + embedder_client: OpenAIEmbedderClient, + connector: Neo4jConnector, + config: Optional[ExtractionPipelineConfig] = None, + progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, + embedding_id: Optional[str] = None, + ontology_types: Optional[OntologyTypeList] = None, + enable_general_types: bool = True, + language: str = "zh", ): """ 初始化流水线编排器 @@ -123,7 +121,7 @@ class ExtractionOrchestrator: self.progress_callback = progress_callback # 保存进度回调函数 self.embedding_id = embedding_id # 保存嵌入模型ID self.language = language # 保存语言配置 - + # 处理本体类型配置 # 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并 # 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合 @@ -146,7 +144,7 @@ class ExtractionOrchestrator: self.ontology_types = ontology_types if not enable_general_types and ontology_types: logger.info("enable_general_types=False,仅使用场景类型") - + # 保存去重消歧的详细记录(内存中的数据结构) self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录 self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录 @@ -157,19 +155,25 @@ class ExtractionOrchestrator: llm_client=llm_client, config=self.config.statement_extraction, ) - self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language) + self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types, + language=language) self.temporal_extractor = TemporalExtractor(llm_client=llm_client) logger.info("ExtractionOrchestrator 初始化完成") async def run( - self, - dialog_data_list: List[DialogData], - is_pilot_run: bool = False, - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialog_data_list: List[DialogData], + is_pilot_run: bool = False, + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + dict ]: """ 运行完整的知识提取流水线(优化版:并行执行) @@ -202,13 +206,12 @@ class ExtractionOrchestrator: # 步骤 1: 陈述句提取 logger.info("步骤 1/6: 陈述句提取(全局分块级并行)") dialog_data_list = await self._extract_statements(dialog_data_list) - + # 收集陈述句内容和统计数量 all_statements_list = [] for dialog in dialog_data_list: for chunk in dialog.chunks: all_statements_list.extend(chunk.statements) - len(all_statements_list) # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") @@ -220,7 +223,7 @@ class ExtractionOrchestrator: chunk_embedding_maps, dialog_embeddings, ) = await self._parallel_extract_and_embed(dialog_data_list) - + # 收集实体和三元组内容,并统计数量 all_entities_list = [] all_triplets_list = [] @@ -229,10 +232,6 @@ class ExtractionOrchestrator: if triplet_info: all_entities_list.extend(triplet_info.entities) all_triplets_list.extend(triplet_info.triplets) - - len(all_entities_list) - len(all_triplets_list) - sum(len(temporal_map) for temporal_map in temporal_maps) # 步骤 3: 生成实体嵌入(依赖三元组提取结果) logger.info("步骤 3/6: 生成实体嵌入") @@ -252,9 +251,9 @@ class ExtractionOrchestrator: # 步骤 5: 创建节点和边 logger.info("步骤 5/6: 创建节点和边") - + # 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送 - + ( dialogue_nodes, chunk_nodes, @@ -273,9 +272,9 @@ class ExtractionOrchestrator: logger.info("步骤 6/6: 去重和消歧(试运行模式:仅第一层去重)") else: logger.info("步骤 6/6: 两阶段去重和消歧") - + # 注意:deduplication 消息已在创建节点和边完成后立即发送 - + result = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -287,8 +286,6 @@ class ExtractionOrchestrator: dialog_data_list, ) - - logger.info(f"知识提取流水线运行完成({mode_str})") return result @@ -297,7 +294,7 @@ class ExtractionOrchestrator: raise async def _extract_statements( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[DialogData]: """ 从对话中提取陈述句(流式输出版本:边提取边发送进度) @@ -313,7 +310,7 @@ class ExtractionOrchestrator: # 收集所有分块及其元数据 all_chunks = [] chunk_metadata = [] # (dialog_idx, chunk_idx) - + for d_idx, dialog in enumerate(dialog_data_list): dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None for c_idx, chunk in enumerate(dialog.chunks): @@ -321,7 +318,7 @@ class ExtractionOrchestrator: chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") - + # 用于跟踪已完成的分块数量 completed_chunks = 0 total_chunks = len(all_chunks) @@ -332,7 +329,7 @@ class ExtractionOrchestrator: chunk, end_user_id, dialogue_content = chunk_data try: statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content) - + # 流式输出:每提取完一个分块的陈述句,立即发送进度 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 completed_chunks += 1 @@ -347,11 +344,11 @@ class ExtractionOrchestrator: "statement_index_in_chunk": idx + 1 } await self.progress_callback( - "knowledge_extraction_result", - f"陈述句提取中 ({completed_chunks}/{total_chunks})", + "knowledge_extraction_result", + f"陈述句提取中 ({completed_chunks}/{total_chunks})", stmt_result ) - + return statements except Exception as e: logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}") @@ -381,7 +378,7 @@ class ExtractionOrchestrator: # 保存陈述句到文件(试运行和正式模式都需要) self.statement_extractor.save_statements(all_statements) - + logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句") # 试运行模式下,所有分块提取完成后发送完成事件 @@ -395,7 +392,7 @@ class ExtractionOrchestrator: return dialog_data_list async def _extract_triplets( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取三元组(流式输出版本:边提取边发送进度) @@ -411,7 +408,7 @@ class ExtractionOrchestrator: # 收集所有陈述句及其元数据 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, chunk_content) - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -419,7 +416,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组") - + # 用于跟踪已完成的陈述句数量 completed_statements = 0 len(all_statements) @@ -430,11 +427,11 @@ class ExtractionOrchestrator: statement, chunk_content = stmt_data try: triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content) - + # 注意:不再发送三元组提取的流式输出 # 三元组提取在后台执行,但不向前端发送详细信息 completed_statements += 1 - + return triplet_info except Exception as e: logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}") @@ -450,7 +447,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 triplet_maps = [{} for _ in dialog_data_list] all_responses = [] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -478,7 +475,7 @@ class ExtractionOrchestrator: return triplet_maps async def _extract_temporal( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取时间信息(流式输出版本:边提取边发送进度) @@ -502,13 +499,13 @@ class ExtractionOrchestrator: temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None) temporal_maps.append(temporal_map) return temporal_maps - + logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)") # 收集所有需要提取时间的陈述句 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, ref_dates) - + for d_idx, dialog in enumerate(dialog_data_list): # 获取参考日期 ref_dates = {} @@ -517,11 +514,11 @@ class ExtractionOrchestrator: ref_dates['conversation_date'] = dialog.metadata['conversation_date'] if 'publication_date' in dialog.metadata: ref_dates['publication_date'] = dialog.metadata['publication_date'] - + if not ref_dates: from datetime import datetime ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")} - + for chunk in dialog.chunks: for statement in chunk.statements: # 跳过 ATEMPORAL 类型的陈述句 @@ -531,7 +528,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取") - + # 用于跟踪已完成的时间提取数量 completed_temporal = 0 len(all_statements) @@ -542,11 +539,11 @@ class ExtractionOrchestrator: statement, ref_dates = stmt_data try: temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) - + # 注意:不再发送时间提取的流式输出 # 时间提取在后台执行,但不向前端发送详细信息 completed_temporal += 1 - + return temporal_range except Exception as e: logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}") @@ -559,7 +556,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 temporal_maps = [{} for _ in dialog_data_list] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -585,7 +582,7 @@ class ExtractionOrchestrator: return temporal_maps async def _extract_emotions( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) @@ -601,36 +598,36 @@ class ExtractionOrchestrator: # 收集所有陈述句及其配置 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id) - + # 获取第一个对话的config_id来加载配置 config_id = None if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): config_id = dialog_data_list[0].config_id - + # 加载MemoryConfig memory_config = None if config_id: try: from app.db import SessionLocal from app.repositories.memory_config_repository import MemoryConfigRepository - + db = SessionLocal() try: memory_config = MemoryConfigRepository.get_by_id(db, config_id) finally: db.close() - + if memory_config and not memory_config.emotion_enabled: logger.info("情绪提取已在配置中禁用,跳过情绪提取") return [{} for _ in dialog_data_list] - + except Exception as e: logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取") return [{} for _ in dialog_data_list] else: logger.info("未找到config_id,跳过情绪提取") return [{} for _ in dialog_data_list] - + # 如果配置未启用情绪提取,直接返回空映射 if not memory_config or not memory_config.emotion_enabled: logger.info("情绪提取未启用,跳过") @@ -639,7 +636,7 @@ class ExtractionOrchestrator: # 收集所有陈述句(只收集 speaker 为 "user" 的) total_statements = 0 filtered_statements = 0 - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -655,12 +652,12 @@ class ExtractionOrchestrator: # 初始化情绪提取服务 # 如果 emotion_model_id 为空,回退到工作空间默认 LLM from app.services.emotion_extraction_service import EmotionExtractionService - + emotion_model_id = memory_config.emotion_model_id if not emotion_model_id and memory_config.workspace_id: from app.repositories.workspace_repository import get_workspace_models_configs from app.db import SessionLocal - + db = SessionLocal() try: workspace_models = get_workspace_models_configs(db, memory_config.workspace_id) @@ -669,7 +666,7 @@ class ExtractionOrchestrator: logger.info(f"emotion_model_id 为空,使用工作空间默认 LLM: {emotion_model_id}") finally: db.close() - + emotion_service = EmotionExtractionService( llm_id=emotion_model_id if emotion_model_id else None ) @@ -689,7 +686,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 emotion_maps = [{} for _ in dialog_data_list] successful_extractions = 0 - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -706,7 +703,7 @@ class ExtractionOrchestrator: return emotion_maps async def _parallel_extract_and_embed( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[Dict[str, Any]], List[Dict[str, Any]], @@ -757,7 +754,7 @@ class ExtractionOrchestrator: triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - + if isinstance(results[3], Exception): logger.error(f"基础嵌入生成失败: {results[3]}") statement_embedding_maps = [{} for _ in dialog_data_list] @@ -777,7 +774,7 @@ class ExtractionOrchestrator: ) async def _generate_basic_embeddings( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: """ 生成基础嵌入向量(陈述句、分块、对话) @@ -810,7 +807,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") raise ValueError("embedding_id is required but was not provided") - + # 只生成陈述句、分块和对话的嵌入(不包括实体) statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation( dialog_data_list, self.embedding_id @@ -836,7 +833,7 @@ class ExtractionOrchestrator: ) async def _generate_entity_embeddings( - self, triplet_maps: List[Dict[str, Any]] + self, triplet_maps: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """ 生成实体嵌入向量 @@ -861,7 +858,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") return triplet_maps - + # 生成实体嵌入 updated_triplet_maps = await generate_entity_embeddings_from_triplets( triplet_maps, self.embedding_id @@ -874,17 +871,15 @@ class ExtractionOrchestrator: logger.error(f"实体嵌入生成失败: {e}", exc_info=True) return triplet_maps - - async def _assign_extracted_data( - self, - dialog_data_list: List[DialogData], - temporal_maps: List[Dict[str, Any]], - triplet_maps: List[Dict[str, Any]], - emotion_maps: List[Dict[str, Any]], - statement_embedding_maps: List[Dict[str, List[float]]], - chunk_embedding_maps: List[Dict[str, List[float]]], - dialog_embeddings: List[List[float]], + self, + dialog_data_list: List[DialogData], + temporal_maps: List[Dict[str, Any]], + triplet_maps: List[Dict[str, Any]], + emotion_maps: List[Dict[str, Any]], + statement_embedding_maps: List[Dict[str, List[float]]], + chunk_embedding_maps: List[Dict[str, List[float]]], + dialog_embeddings: List[List[float]], ) -> List[DialogData]: """ 将提取的数据赋值到语句 @@ -906,12 +901,12 @@ class ExtractionOrchestrator: # 确保列表长度匹配 expected_length = len(dialog_data_list) if ( - len(temporal_maps) != expected_length - or len(triplet_maps) != expected_length - or len(emotion_maps) != expected_length - or len(statement_embedding_maps) != expected_length - or len(chunk_embedding_maps) != expected_length - or len(dialog_embeddings) != expected_length + len(temporal_maps) != expected_length + or len(triplet_maps) != expected_length + or len(emotion_maps) != expected_length + or len(statement_embedding_maps) != expected_length + or len(chunk_embedding_maps) != expected_length + or len(dialog_embeddings) != expected_length ): logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " @@ -999,7 +994,7 @@ class ExtractionOrchestrator: return dialog_data_list async def _create_nodes_and_edges( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -1007,7 +1002,7 @@ class ExtractionOrchestrator: List[ExtractedEntityNode], List[StatementChunkEdge], List[StatementEntityEdge], - List[EntityEntityEdge], + List[EntityEntityEdge] ]: """ 创建图数据库节点和边 @@ -1021,7 +1016,7 @@ class ExtractionOrchestrator: 包含所有节点和边的元组 """ logger.info("开始创建节点和边") - + # 注意:开始消息已在 run 方法中发送,这里不再重复发送 dialogue_nodes = [] @@ -1034,7 +1029,7 @@ class ExtractionOrchestrator: # 用于去重的集合 entity_id_set = set() - + # 用于跟踪进度 total_dialogs = len(dialog_data_list) processed_dialogs = 0 @@ -1083,15 +1078,19 @@ class ExtractionOrchestrator: name=f"Statement_{statement.id}", # 添加必需的 name 字段 chunk_id=chunk.id, stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 - temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 - connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), + # 添加必需的 temporal_info 字段 + connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 statement_embedding=statement.statement_embedding, - valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, - invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, + valid_at=statement.temporal_validity.valid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, + invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, @@ -1120,7 +1119,7 @@ class ExtractionOrchestrator: # 创建实体索引到ID的映射(支持多种索引方式) entity_idx_to_id = {} - + # 创建实体节点 for entity_idx, entity in enumerate(triplet_info.entities): # 映射实体索引到实体ID(使用多个键以提高容错性) @@ -1128,7 +1127,7 @@ class ExtractionOrchestrator: entity_idx_to_id[entity.entity_idx] = entity.id # 2. 使用枚举索引(从0开始) entity_idx_to_id[entity_idx] = entity.id - + if entity.id not in entity_id_set: entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') entity_node = ExtractedEntityNode( @@ -1141,7 +1140,8 @@ class ExtractionOrchestrator: example=getattr(entity, 'example', ''), # 新增:传递示例字段 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 - connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 @@ -1171,7 +1171,7 @@ class ExtractionOrchestrator: # 将三元组中的整数索引映射到实体ID subject_entity_id = entity_idx_to_id.get(triplet.subject_id) object_entity_id = entity_idx_to_id.get(triplet.object_id) - + # 只有当两个实体ID都存在时才创建边 if subject_entity_id and object_entity_id: entity_entity_edge = EntityEntityEdge( @@ -1186,7 +1186,7 @@ class ExtractionOrchestrator: expired_at=dialog_data.expired_at, ) entity_entity_edges.append(entity_entity_edge) - + # 流式输出:每创建一个关系边,立即发送进度(限制发送数量) if self.progress_callback and len(entity_entity_edges) <= 10: # 获取实体名称 @@ -1202,8 +1202,8 @@ class ExtractionOrchestrator: "dialog_progress": f"{processed_dialogs}/{total_dialogs}" } await self.progress_callback( - "creating_nodes_edges_result", - f"关系创建中 ({processed_dialogs}/{total_dialogs})", + "creating_nodes_edges_result", + f"关系创建中 ({processed_dialogs}/{total_dialogs})", relationship_result ) else: @@ -1211,7 +1211,7 @@ class ExtractionOrchestrator: missing_subject = "subject" if not subject_entity_id else "" missing_object = "object" if not object_entity_id else "" missing_both = " and " if (not subject_entity_id and not object_entity_id) else "" - + logger.debug( f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: " f"subject_id={triplet.subject_id} ({triplet.subject_name}), " @@ -1228,7 +1228,7 @@ class ExtractionOrchestrator: f"陈述句-实体边: {len(statement_entity_edges)}, " f"实体-实体边: {len(entity_entity_edges)}" ) - + # 进度回调:创建节点和边完成,传递结果统计 # 注意:具体的关系创建结果已经在创建过程中实时发送了 if self.progress_callback: @@ -1254,19 +1254,24 @@ class ExtractionOrchestrator: ) async def _run_dedup_and_write_summary( - self, - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + dict ]: """ 执行两阶段去重并写入汇总 @@ -1288,11 +1293,11 @@ class ExtractionOrchestrator: - 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) """ logger.info("开始两阶段实体去重和消歧") - + # 进度回调:发送去重消歧开始消息 if self.progress_callback: await self.progress_callback("deduplication", "正在去重消歧...") - + logger.info( f"去重前: {len(entity_nodes)} 个实体节点, " f"{len(statement_entity_edges)} 条陈述句-实体边, " @@ -1307,7 +1312,7 @@ class ExtractionOrchestrator: from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( deduplicate_entities_and_edges, ) - + dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges( entity_nodes, statement_entity_edges, @@ -1317,10 +1322,10 @@ class ExtractionOrchestrator: dedup_config=self.config.deduplication, llm_client=self.llm_client, ) - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes) - + result_tuple = ( dialogue_nodes, chunk_nodes, @@ -1330,7 +1335,7 @@ class ExtractionOrchestrator: dedup_statement_entity_edges, dedup_entity_entity_edges, ) - + final_entity_nodes = dedup_entity_nodes final_statement_entity_edges = dedup_statement_entity_edges final_entity_entity_edges = dedup_entity_entity_edges @@ -1361,7 +1366,7 @@ class ExtractionOrchestrator: final_entity_entity_edges, dedup_details, ) = result_tuple - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) @@ -1375,12 +1380,12 @@ class ExtractionOrchestrator: f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, " f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" ) - + # 流式输出:实时输出去重消歧的具体结果 if self.progress_callback: # 分析实体合并情况(使用内存中的记录) merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes) - + # 逐个输出去重合并的实体示例 for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果 dedup_result = { @@ -1391,10 +1396,10 @@ class ExtractionOrchestrator: "message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并" } await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result) - + # 分析实体消歧情况(使用内存中的记录) disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes) - + # 逐个输出实体消歧的结果 for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果 disamb_result = { @@ -1407,14 +1412,13 @@ class ExtractionOrchestrator: "message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}" } await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result) - + # 进度回调:去重消歧完成,传递去重和消歧的具体效果 await self._send_dedup_progress_callback( len(entity_nodes), len(final_entity_nodes), len(statement_entity_edges), len(final_statement_entity_edges), len(entity_entity_edges), len(final_entity_entity_edges) ) - # 写入提取结果汇总(试运行和正式模式都需要生成) try: @@ -1436,10 +1440,10 @@ class ExtractionOrchestrator: raise def _save_dedup_details( - self, - dedup_details: Dict[str, Any], - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + dedup_details: Dict[str, Any], + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ): """ 保存去重消歧的详细记录到实例变量(基于内存数据结构) @@ -1452,7 +1456,7 @@ class ExtractionOrchestrator: try: # 保存ID重定向映射 self.id_redirect_map = dedup_details.get("id_redirect", {}) - + # 处理精确匹配的合并记录 exact_merge_map = dedup_details.get("exact_merge_map", {}) for key, info in exact_merge_map.items(): @@ -1466,7 +1470,7 @@ class ExtractionOrchestrator: "merged_count": len(merged_ids), "merged_ids": list(merged_ids) }) - + # 处理模糊匹配的合并记录 fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", []) for record in fuzzy_merge_records: @@ -1486,7 +1490,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}") - + # 处理LLM去重的合并记录 llm_decision_records = dedup_details.get("llm_decision_records", []) for record in llm_decision_records: @@ -1505,7 +1509,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}") - + # 处理消歧记录 disamb_records = dedup_details.get("disamb_records", []) for record in disamb_records: @@ -1520,14 +1524,14 @@ class ExtractionOrchestrator: entity1_type = match.group(2) match.group(3).strip() entity2_type = match.group(4) - + # 提取置信度和原因 conf_match = re.search(r"conf=([0-9.]+)", str(record)) confidence = conf_match.group(1) if conf_match else "unknown" - + reason_match = re.search(r"reason=([^|]+)", str(record)) reason = reason_match.group(1).strip() if reason_match else "" - + self.dedup_disamb_records.append({ "entity_name": entity1_name, "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}", @@ -1536,16 +1540,17 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") - - logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") - + + logger.info( + f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") + except Exception as e: logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) async def _analyze_entity_merges( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) @@ -1566,28 +1571,28 @@ class ExtractionOrchestrator: key=lambda x: x.get("merged_count", 0), reverse=True ) - + merge_info = [] for record in sorted_records: merge_info.append({ "main_entity_name": record.get("entity_name", "未知实体"), "merged_count": record.get("merged_count", 1) }) - + return merge_info - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体合并记录") return [] - + except Exception as e: logger.error(f"分析实体合并情况失败: {e}", exc_info=True) return [] async def _analyze_entity_disambiguation( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) @@ -1603,11 +1608,11 @@ class ExtractionOrchestrator: # 直接使用保存的消歧记录 if self.dedup_disamb_records: return self.dedup_disamb_records - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体消歧记录") return [] - + except Exception as e: logger.error(f"分析实体消歧情况失败: {e}", exc_info=True) return [] @@ -1624,7 +1629,7 @@ class ExtractionOrchestrator: """ type_mapping = { "Person": "人物实体节点", - "Organization": "组织实体节点", + "Organization": "组织实体节点", "ORG": "组织实体节点", "Location": "地点实体节点", "LOC": "地点实体节点", @@ -1645,9 +1650,9 @@ class ExtractionOrchestrator: return type_mapping.get(entity_type, f"{entity_type}实体节点") async def _output_relationship_creation_results( - self, - entity_entity_edges: List[EntityEntityEdge], - entity_nodes: List[ExtractedEntityNode] + self, + entity_entity_edges: List[EntityEntityEdge], + entity_nodes: List[ExtractedEntityNode] ): """ 输出关系创建结果 @@ -1659,13 +1664,13 @@ class ExtractionOrchestrator: try: # 创建实体ID到名称的映射 entity_id_to_name = {node.id: node.name for node in entity_nodes} - + # 输出关系创建结果 for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系 source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}") target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}") relation_type = edge.relation_type - + relationship_result = { "result_type": "relationship_creation", "relationship_index": i + 1, @@ -1674,20 +1679,20 @@ class ExtractionOrchestrator: "target_entity": target_name, "relationship_text": f"{source_name} -[{relation_type}]-> {target_name}" } - + await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result) - + except Exception as e: logger.error(f"输出关系创建结果失败: {e}", exc_info=True) async def _send_dedup_progress_callback( - self, - original_entities: int, - final_entities: int, - original_stmt_edges: int, - final_stmt_edges: int, - original_ent_edges: int, - final_ent_edges: int, + self, + original_entities: int, + final_entities: int, + original_stmt_edges: int, + final_stmt_edges: int, + original_ent_edges: int, + final_ent_edges: int, ): """ 发送去重消歧完成的进度回调,传递具体的去重和消歧效果 @@ -1703,19 +1708,20 @@ class ExtractionOrchestrator: try: # 解析去重消歧报告文件,获取具体的去重和消歧效果 dedup_details = await self._parse_dedup_report() - + # 计算去重效果统计 entities_reduced = original_entities - final_entities stmt_edges_reduced = original_stmt_edges - final_stmt_edges ent_edges_reduced = original_ent_edges - final_ent_edges - + # 构建进度回调数据 dedup_stats = { "entities": { "original_count": original_entities, "final_count": final_entities, "reduced_count": entities_reduced, - "reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, + "reduction_rate": round(entities_reduced / original_entities * 100, + 1) if original_entities > 0 else 0, }, "statement_entity_edges": { "original_count": original_stmt_edges, @@ -1734,9 +1740,9 @@ class ExtractionOrchestrator: "total_disambiguations": dedup_details.get("total_disambiguations", 0), } } - + await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats) - + except Exception as e: logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True) # 即使解析失败,也发送基本的统计信息 @@ -1766,12 +1772,12 @@ class ExtractionOrchestrator: disamb_examples = [] total_merges = 0 total_disambiguations = 0 - + # 处理合并记录 for record in self.dedup_merge_records: merge_count = record.get("merged_count", 0) total_merges += merge_count - + dedup_examples.append({ "type": record.get("type", "未知"), "entity_name": record.get("entity_name", "未知实体"), @@ -1779,30 +1785,31 @@ class ExtractionOrchestrator: "merge_count": merge_count, "description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个" }) - + # 处理消歧记录 for record in self.dedup_disamb_records: total_disambiguations += 1 - + # 从消歧类型中提取实体类型信息 disamb_type = record.get("disamb_type", "") entity_name = record.get("entity_name", "未知实体") - + disamb_examples.append({ "entity1_name": entity_name, - "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", + "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", + "").strip() if "vs" in disamb_type else "未知", "entity2_name": entity_name, "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", "description": f"{entity_name},消歧区分成功" }) - + return { "dedup_examples": dedup_examples[:5], # 只返回前5个示例 "disamb_examples": disamb_examples[:5], # 只返回前5个示例 "total_merges": total_merges, "total_disambiguations": total_disambiguations, } - + except Exception as e: logger.error(f"获取去重报告失败: {e}", exc_info=True) return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0} @@ -1815,9 +1822,9 @@ class ExtractionOrchestrator: async def get_chunked_dialogs( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "group_1", - indices: Optional[List[int]] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "group_1", + indices: Optional[List[int]] = None, ) -> List[DialogData]: """从测试数据生成分块对话 @@ -1831,7 +1838,7 @@ async def get_chunked_dialogs( """ import json import re - + # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") with open(testdata_path, "r", encoding="utf-8") as f: @@ -1845,7 +1852,7 @@ async def get_chunked_dialogs( else: # 默认使用所有数据 selected_data = test_data - + for data in selected_data: # 解析对话上下文 context_text = data["context"] @@ -1861,7 +1868,7 @@ async def get_chunked_dialogs( if m: y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - + dialog_metadata: Dict[str, Any] = {} if conv_date: dialog_metadata["conversation_date"] = conv_date @@ -1890,7 +1897,7 @@ async def get_chunked_dialogs( end_user_id=end_user_id, metadata=dialog_metadata, ) - + # 创建分块器并处理对话 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, @@ -1913,7 +1920,7 @@ async def get_chunked_dialogs( from app.core.config import settings settings.ensure_memory_output_dir() output_path = settings.get_memory_output_path("chunker_test_output.txt") - + import json with open(output_path, "w", encoding="utf-8") as f: json.dump( @@ -1924,10 +1931,10 @@ async def get_chunked_dialogs( def preprocess_data( - input_path: Optional[str] = None, - output_path: Optional[str] = None, - skip_cleaning: bool = True, - indices: Optional[List[int]] = None + input_path: Optional[str] = None, + output_path: Optional[str] = None, + skip_cleaning: bool = True, + indices: Optional[List[int]] = None ) -> List[DialogData]: """数据预处理 @@ -1946,7 +1953,8 @@ def preprocess_data( ) preprocessor = DataPreprocessor() try: - cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) + cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, + skip_cleaning=skip_cleaning, indices=indices) logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") return cleaned_data except Exception as e: @@ -1955,9 +1963,9 @@ def preprocess_data( async def get_chunked_dialogs_from_preprocessed( - data: List[DialogData], - chunker_strategy: str = "RecursiveChunker", - llm_client: Optional[Any] = None, + data: List[DialogData], + chunker_strategy: str = "RecursiveChunker", + llm_client: Optional[Any] = None, ) -> List[DialogData]: """从预处理后的数据中生成分块 @@ -1972,31 +1980,31 @@ async def get_chunked_dialogs_from_preprocessed( logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===") if not data: raise ValueError("预处理数据为空,无法进行分块") - + all_chunked_dialogs: List[DialogData] = [] from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, ) - + for dialog_data in data: chunker = DialogueChunker(chunker_strategy, llm_client=llm_client) chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = chunks all_chunked_dialogs.append(dialog_data) - + return all_chunked_dialogs async def get_chunked_dialogs_with_preprocessing( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "default", - user_id: str = "default", - apply_id: str = "default", - indices: Optional[List[int]] = None, - input_data_path: Optional[str] = None, - llm_client: Optional[Any] = None, - skip_cleaning: bool = True, - pruning_config: Optional[Dict] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "default", + user_id: str = "default", + apply_id: str = "default", + indices: Optional[List[int]] = None, + input_data_path: Optional[str] = None, + llm_client: Optional[Any] = None, + skip_cleaning: bool = True, + pruning_config: Optional[Dict] = None, ) -> List[DialogData]: """包含数据预处理步骤的完整分块流程 @@ -2020,7 +2028,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path = os.path.join( os.path.dirname(__file__), "../../data", "testdata.json" ) - + # 步骤1: 数据预处理(包含索引筛选) from app.core.config import settings settings.ensure_memory_output_dir() @@ -2030,37 +2038,38 @@ async def get_chunked_dialogs_with_preprocessing( skip_cleaning=skip_cleaning, indices=indices, ) - + # 设置 end_user_id for dd in preprocessed_data: dd.end_user_id = end_user_id - + # 步骤2: 语义剪枝 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) from app.core.memory.models.config_models import PruningConfig - + # 构建剪枝配置 if pruning_config: # 使用传入的配置 config = PruningConfig(**pruning_config) - logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + logger.debug( + f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") else: # 使用默认配置(关闭剪枝) config = None logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") - + pruner = SemanticPruner(config=config, llm_client=llm_client) - + # 记录单对话场景下剪枝前的消息数量 single_dialog_original_msgs = None if len(preprocessed_data) == 1 and preprocessed_data[0].context: single_dialog_original_msgs = len(preprocessed_data[0].context.msgs) preprocessed_data = await pruner.prune_dataset(preprocessed_data) - + # 单对话:打印清洗与剪枝信息 if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None: remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0 @@ -2071,7 +2080,7 @@ async def get_chunked_dialogs_with_preprocessing( ) else: logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") - + # 保存剪枝后的数据 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import ( @@ -2084,7 +2093,7 @@ async def get_chunked_dialogs_with_preprocessing( logger.error(f"保存剪枝结果失败:{se}") except Exception as e: logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}") - + # 步骤3: 对话分块 return await get_chunked_dialogs_from_preprocessed( preprocessed_data, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index 443ee36a..5e39ba36 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -188,7 +188,6 @@ async def _process_chunk_summary( response_model=MemorySummaryResponse, ) summary_text = structured.summary.strip() - # Generate title and type for the summary title = None episodic_type = None diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index d4e1b488..60f1257e 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -374,7 +374,9 @@ class VariablePool: self.variables = deepcopy(pool.variables) def is_file_variable(self, selector): - variable_struct = self._get_variable_struct(selector) + variable_struct = self.get_instance(selector, default=None, strict=False) + if variable_struct is None: + return False if isinstance(variable_struct, FileVariable): return True elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 0e3fecee..7f2b8aa6 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -623,7 +623,6 @@ class BaseNode(ABC): async def process_message( api_config: ModelInfo, content: str | dict | FileObject, - end_user_id: str, enable_file=False ) -> list | str | None: provider = api_config.provider @@ -642,8 +641,8 @@ class BaseNode(ABC): return content elif isinstance(content, FileObject): - if content.content_cache.get(provider): - return content.content_cache[provider] + if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): + return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] with get_db_read() as db: multimodel_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( @@ -655,12 +654,11 @@ class BaseNode(ABC): ) file_obj.set_content(content.get_content()) message = await multimodel_service.process_files( - end_user_id, [file_obj], ) content.set_content(file_obj.get_content()) if message: - content.content_cache[provider] = message + content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message return message return None raise TypeError(f'Unexpect input value type - {type(content)}') diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index b293d1f4..66a0f1ac 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -144,7 +144,6 @@ class LLMNode(BaseNode): f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") messages_config = self.typed_config.messages - if messages_config: # 使用 LangChain 消息格式 messages = [] @@ -153,7 +152,6 @@ class LLMNode(BaseNode): content_template = msg_config.content content_template = self._render_context(content_template, variable_pool) content = self._render_template(content_template, variable_pool) - user_id = self.get_variable("sys.user_id", variable_pool) # 根据角色创建对应的消息对象 if role == "system": messages.append({ @@ -161,32 +159,31 @@ class LLMNode(BaseNode): "content": await self.process_message( model_info, content, - user_id, self.typed_config.vision, ) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file.value, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -200,7 +197,7 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(model_info, file, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file, self.typed_config.vision) if content: file_content.extend(content) history_message.append( @@ -210,7 +207,6 @@ class LLMNode(BaseNode): message["content"] = await self.process_message( model_info, message["content"], - user_id, self.typed_config.vision ) history_message.append(message) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 82363056..cbdad0fa 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode): write_message_task.delay( end_user_id=end_user_id, message=messages, + file_messages=multimodal_memories, config_id=str(self.typed_config.config_id), storage_type=state["memory_storage_type"], user_rag_memory_id=state["user_rag_memory_id"] diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py index 1095a386..616f7f3a 100644 --- a/api/app/models/memory_config_model.py +++ b/api/app/models/memory_config_model.py @@ -30,6 +30,9 @@ class MemoryConfig(Base): llm_id = Column(String, nullable=True, comment="LLM模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") + vision_id = Column(String, nullable=True, comment="视觉模型配置ID") + audio_id = Column(String, nullable=True, comment="语音模型配置ID") + video_id = Column(String, nullable=True, comment="视频模型配置ID") # 记忆萃取引擎配置 enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 5c2f81a7..6fb41914 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -9,21 +9,22 @@ Classes: """ import uuid -from uuid import UUID from typing import Dict, List, Optional, Tuple +from uuid import UUID + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session + from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger from app.models.memory_config_model import MemoryConfig +from app.models.workspace_model import Workspace from app.schemas.memory_storage_schema import ( - ConfigKey, ConfigParamsCreate, ConfigUpdate, ConfigUpdateExtracted, ConfigUpdateForget, ) -from sqlalchemy import desc, select -from sqlalchemy.orm import Session - from app.utils.config_utils import resolve_config_id # 获取数据库专用日志器 @@ -157,7 +158,7 @@ class MemoryConfigRepository: return memory_config_obj @staticmethod - def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: + def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: @@ -491,7 +492,10 @@ class MemoryConfigRepository: raise @staticmethod - def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: + def get_config_with_workspace( + db: Session, + config_id: uuid.UUID | int | str + ) -> Optional[tuple[MemoryConfig, Workspace]]: """Get memory config and its associated workspace information Args: @@ -506,8 +510,6 @@ class MemoryConfigRepository: """ import time - from app.models.workspace_model import Workspace - start_time = time.time() config_id = resolve_config_id(config_id, db) @@ -594,7 +596,7 @@ class MemoryConfigRepository: db_logger.debug( f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") - return (config, workspace) + return config, workspace except ValueError: # Re-raise known business exceptions @@ -630,7 +632,7 @@ class MemoryConfigRepository: List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ from app.models.ontology_scene import OntologyScene - + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: @@ -694,7 +696,7 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 默认配置对象,不存在则返回None """ db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}") - + try: # 优先查找显式标记为默认的配置 stmt = ( @@ -706,13 +708,13 @@ class MemoryConfigRepository: ) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"找到默认配置: config_id={config.config_id}") return config - + # 回退:获取最早创建的活跃配置 stmt = ( select(MemoryConfig) @@ -723,25 +725,25 @@ class MemoryConfigRepository: .order_by(MemoryConfig.created_at.asc()) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}") else: db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}") - + return config - + except Exception as e: db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}") raise @staticmethod def get_with_fallback( - db: Session, - config_id: Optional[uuid.UUID], - workspace_id: uuid.UUID + db: Session, + config_id: Optional[uuid.UUID], + workspace_id: uuid.UUID ) -> Optional[MemoryConfig]: """获取记忆配置,支持回退到工作空间默认配置 @@ -756,19 +758,18 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 配置对象,如果都不存在则返回None """ db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}") - + if not config_id: db_logger.debug("config_id 为空,使用工作空间默认配置") return MemoryConfigRepository.get_workspace_default(db, workspace_id) - + config = db.get(MemoryConfig, config_id) - + if config: return config - + db_logger.warning( f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}" ) - - return MemoryConfigRepository.get_workspace_default(db, workspace_id) + return MemoryConfigRepository.get_workspace_default(db, workspace_id) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 42c178b3..3a017089 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,7 +1,8 @@ from typing import List, Optional -from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode +from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ + MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): print(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result + async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add dialogue nodes to Neo4j database. @@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC print(f"Error creating statement nodes: {e}") return None + async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add chunk nodes to Neo4j in batch. @@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> return None - -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: +async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[ + List[str]]: """Add memory summary nodes to Neo4j in batch. Args: @@ -211,7 +214,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector "summary_embedding": s.summary_embedding if s.summary_embedding else None, "config_id": s.config_id, # 添加 config_id }) - + result = await connector.execute_query( MEMORY_SUMMARY_NODE_SAVE, summaries=flattened @@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector return None +async def add_perceptual_nodes( + perceptuals: list, + connector: Neo4jConnector, + embedder_client=None, +) -> Optional[List[str]]: + """Add perceptual memory nodes to Neo4j in batch. + + Args: + perceptuals: List of MemoryPerceptualModel objects from PostgreSQL + connector: Neo4j connector instance + embedder_client: Optional embedder client for generating summary embeddings + + Returns: + List of created node UUIDs or None if failed + """ + if not perceptuals: + print("No perceptual nodes to add") + return [] + + try: + flattened = [] + for p in perceptuals: + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + # 生成 summary embedding(如果有 embedder_client) + summary_embedding = None + if embedder_client and p.summary: + try: + summary_embedding = (await embedder_client.response([p.summary]))[0] + except Exception as emb_err: + print(f"Failed to embed perceptual summary: {emb_err}") + + flattened.append({ + "id": str(p.id), + "end_user_id": str(p.end_user_id), + "perceptual_type": p.perceptual_type, + "file_path": p.file_path or "", + "file_name": p.file_name or "", + "file_ext": p.file_ext or "", + "summary": p.summary or "", + "keywords": content_meta.get("keywords", []), + "topic": content_meta.get("topic", ""), + "domain": content_meta.get("domain", ""), + "created_at": p.created_time.isoformat() if p.created_time else None, + "summary_embedding": summary_embedding, + }) + + result = await connector.execute_query( + PERCEPTUAL_NODE_SAVE, + perceptuals=flattened, + ) + created_uuids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j") + return created_uuids + + except Exception as e: + print(f"Failed to save Perceptual nodes to Neo4j: {e}") + return None + + +async def add_perceptual_dialogue_edges( + perceptuals: list, + dialog_id: str, + connector: Neo4jConnector, +) -> Optional[List[str]]: + """Add edges between Perceptual nodes and Dialogue nodes. + + Args: + perceptuals: List of MemoryPerceptualModel objects + dialog_id: The dialogue ID (or ref_id) to link to + connector: Neo4j connector instance + + Returns: + List of created edge element IDs or None if failed + """ + if not perceptuals or not dialog_id: + return [] + + try: + edges = [] + for p in perceptuals: + edges.append({ + "perceptual_id": str(p.id), + "dialog_id": dialog_id, + "end_user_id": str(p.end_user_id), + "created_at": p.created_time.isoformat() if p.created_time else None, + }) + + result = await connector.execute_query( + PERCEPTUAL_DIALOGUE_EDGE_SAVE, + edges=edges, + ) + created_ids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j") + return created_ids + + except Exception as e: + print(f"Failed to save Perceptual-Dialogue edges: {e}") + return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 0ac7dcb1..49dbe2a5 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1323,3 +1323,36 @@ RETURN s.statement AS statement, ORDER BY COALESCE(s.activation_value, 0) DESC LIMIT $limit """ + +# 感知记忆节点保存 +PERCEPTUAL_NODE_SAVE = """ +UNWIND $perceptuals AS p +MERGE (n:Perceptual {id: p.id}) +SET n += { + id: p.id, + end_user_id: p.end_user_id, + perceptual_type: p.perceptual_type, + file_path: p.file_path, + file_name: p.file_name, + file_ext: p.file_ext, + summary: p.summary, + keywords: p.keywords, + topic: p.topic, + domain: p.domain, + created_at: p.created_at, + summary_embedding: p.summary_embedding +} +RETURN n.id AS uuid +""" + +# 感知记忆与对话的关联边 +PERCEPTUAL_DIALOGUE_EDGE_SAVE = """ +UNWIND $edges AS edge +MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id}) +MATCH (d:Dialogue {end_user_id: edge.end_user_id}) +WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id +MERGE (d)-[r:HAS_PERCEPTUAL]->(p) +SET r.end_user_id = edge.end_user_id, + r.created_at = edge.created_at +RETURN elementId(r) AS uuid +""" diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 8d7490fe..e186e54b 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -387,6 +387,12 @@ class MemoryConfig: rerank_model_id: Optional[UUID] = None rerank_model_name: Optional[str] = None + video_model_id: Optional[UUID] = None + video_model_name: Optional[str] = None + vision_model_id: Optional[UUID] = None + vision_model_name: Optional[str] = None + audio_model_id: Optional[UUID] = None + audio_model_name: Optional[str] = None llm_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 604514b4..98f93408 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -141,7 +141,7 @@ class AppChatService: model_type=ModelType.LLM ) multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 调用 Agent(支持多模态) @@ -339,7 +339,7 @@ class AppChatService: model_type=ModelType.LLM ) multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 流式调用 Agent(支持多模态),同时并行启动 TTS diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ba41d323..f7331851 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -600,7 +600,7 @@ class AgentRunService: ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -836,7 +836,7 @@ class AgentRunService: ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 514cb12f..875f02bb 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional from uuid import UUID import redis -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session +from app.cache import InterestMemoryCache from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.messages_tools import ( merge_multiple_search_results, reorder_output_results, ) from app.core.memory.agent.utils.type_classifier import status_typle +from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas import FileInput from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from app.services.memory_perceptual_service import MemoryPerceptualService try: from app.core.memory.utils.log.audit_logger import audit_logger @@ -271,6 +274,7 @@ class MemoryAgentService: self, end_user_id: str, messages: list[dict], + file_messages: list[dict], config_id: Optional[uuid.UUID] | int, db: Session, storage_type: str, @@ -283,6 +287,7 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) messages: Message to write + files: Files to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -342,48 +347,52 @@ class MemoryAgentService: raise ValueError(error_msg) + perceptual_serivce = MemoryPerceptualService(db) + file_content = [] + for message in file_messages: + for file in message["files"]: + file_object = await perceptual_serivce.generate_perceptual_memory( + end_user_id=end_user_id, + memory_config=memory_config, + file=FileInput(**file) + ) + file_content.append(file_object) + + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: if storage_type == "rag": # For RAG storage, convert messages to single string - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) await write_rag(end_user_id, message_text, user_rag_memory_id) return "success" else: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # Convert structured messages to LangChain messages - langchain_messages = [] - for msg in messages: - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - print(100 * '-') - print(langchain_messages) - print(100 * '-') - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config, - "language": language - } - - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - # Convert messages back to string for logging - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, - contents) + await write_neo4j( + end_user_id=end_user_id, + messages=messages, + file_content=file_content, + memory_config=memory_config, + ref_id='', + language=language + ) + for lang in ["zh", "en"]: + deleted = await InterestMemoryCache.delete_interest_distribution( + end_user_id, lang + ) + if deleted: + logger.info( + f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") + return self.writer_messages_deal( + "success", + start_time, + end_user_id, + config_id, + message_text, + { + "status": "success", + "data": messages, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name + } + ) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 01bc6267..9a0fb8ed 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -28,7 +28,7 @@ class MemoryAPIService: 2. Maps end_user_id to end_user_id for memory operations 3. Delegates to MemoryAgentService for actual memory read/write operations """ - + def __init__(self, db: Session): """Initialize MemoryAPIService. @@ -36,11 +36,11 @@ class MemoryAPIService: db: SQLAlchemy database session """ self.db = db - + def validate_end_user( - self, - end_user_id: str, - workspace_id: uuid.UUID + self, + end_user_id: str, + workspace_id: uuid.UUID ) -> EndUser: """Validate that end_user exists and belongs to the workspace. @@ -56,7 +56,7 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace """ logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}") - + # Query end_user by ID try: end_user_uuid = uuid.UUID(end_user_id) @@ -66,7 +66,7 @@ class MemoryAPIService: message=f"Invalid end_user_id format: {end_user_id}", code=BizCode.INVALID_PARAMETER ) - + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first() if not end_user: @@ -75,13 +75,13 @@ class MemoryAPIService: resource_type="EndUser", resource_id=end_user_id ) - + # Verify end_user belongs to the workspace via App relationship app = self.db.query(App).filter( App.id == end_user.app_id, App.is_active.is_(True) ).first() - + if not app: logger.warning(f"App not found for end_user: {end_user_id}") # raise ResourceNotFoundException( @@ -99,7 +99,7 @@ class MemoryAPIService: # message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}", # code=BizCode.FORBIDDEN # ) - + logger.info(f"End user {end_user_id} validated successfully") return end_user @@ -125,13 +125,14 @@ class MemoryAPIService: logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") async def write_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - config_id: str, - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + config_id: str, + storage_type: str = "neo4j", + files: Optional[list]=None, + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Write memory with validation. @@ -153,14 +154,16 @@ class MemoryAPIService: ResourceNotFoundException: If end_user not found BusinessException: If end_user not in authorized workspace or write fails """ + if files is None: + files = list() logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - + try: # Delegate to MemoryAgentService # Convert string message to list[dict] format expected by MemoryAgentService @@ -171,11 +174,12 @@ class MemoryAPIService: config_id=config_id, db=self.db, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id or "" + user_rag_memory_id=user_rag_memory_id or "", + files=files ) - + logger.info(f"Memory write successful for end_user: {end_user_id}") - + # result may be a string "success" or a dict with a "status" key # Preserve the full dict so callers don't silently lose extra fields # (e.g. error codes, metadata) returned by MemoryAgentService. @@ -189,7 +193,7 @@ class MemoryAPIService: "status": result if isinstance(result, str) else "success", "end_user_id": end_user_id, } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -204,16 +208,16 @@ class MemoryAPIService: message=f"Memory write failed: {str(e)}", code=BizCode.MEMORY_WRITE_FAILED ) - + async def read_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - search_switch: str = "0", - config_id: str = "", - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + search_switch: str = "0", + config_id: str = "", + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Read memory with validation. @@ -237,14 +241,13 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace or read fails """ logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - try: # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( @@ -257,15 +260,15 @@ class MemoryAPIService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id or "" ) - + logger.info(f"Memory read successful for end_user: {end_user_id}") - + return { "answer": result.get("answer", ""), "intermediate_outputs": result.get("intermediate_outputs", []), "end_user_id": end_user_id } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -282,8 +285,8 @@ class MemoryAPIService: ) def list_memory_configs( - self, - workspace_id: uuid.UUID, + self, + workspace_id: uuid.UUID, ) -> Dict[str, Any]: """List all memory configs for a workspace. diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index a3751c07..1a4af531 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None): """Validate configuration ID format (supports both UUID and integer).""" if isinstance(config_id, uuid.UUID): return config_id - + if config_id is None: raise InvalidConfigError( "Configuration ID cannot be None", @@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None): if result: logger.info(f"Found config_id {result.config_id} for user_id {config_id}") return result.config_id - + return config_id if isinstance(config_id, str): config_id_stripped = config_id.strip() - + # Try parsing as UUID first try: return uuid.UUID(config_id_stripped) except ValueError: pass - + # Fall back to integer parsing try: parsed_id = int(config_id_stripped) @@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None): field_name="config_id", invalid_value=config_id, ) - + # 如果提供了数据库会话,尝试通过 user_id 查询 config_id if db is not None: # 查询 user_id 匹配的记录 stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id)) result = db.execute(stmt).scalars().first() - + if result: logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}") return result.config_id - + return parsed_id except ValueError: raise InvalidConfigError( @@ -154,10 +154,10 @@ class MemoryConfigService: self.db = db def load_memory_config( - self, - config_id: Optional[UUID] = None, - workspace_id: Optional[UUID] = None, - service_name: str = "MemoryConfigService", + self, + config_id: Optional[UUID] = None, + workspace_id: Optional[UUID] = None, + service_name: str = "MemoryConfigService", ) -> MemoryConfig: """ Load memory configuration from database with optional fallback. @@ -194,14 +194,14 @@ class MemoryConfigService: try: # Use get_config_with_fallback if workspace_id is provided memory_config = None + validated_config_id = None if workspace_id: - validated_config_id = None if config_id: try: validated_config_id = _validate_config_id(config_id, self.db) except Exception: validated_config_id = None - + memory_config = self.get_config_with_fallback( memory_config_id=validated_config_id, workspace_id=workspace_id @@ -210,7 +210,7 @@ class MemoryConfigService: validated_config_id = _validate_config_id(config_id, self.db) from app.models.memory_config_model import MemoryConfig as MemoryConfigModel memory_config = self.db.get(MemoryConfigModel, validated_config_id) - + if not memory_config: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -233,7 +233,7 @@ class MemoryConfigService: result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id) db_query_time = time.time() - db_query_start logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") - + if not result: raise ConfigurationError( f"Workspace not found for config {memory_config.config_id}" @@ -243,10 +243,10 @@ class MemoryConfigService: # Helper function to validate model with workspace fallback def _validate_model_with_fallback( - model_id: str, - model_type: str, - workspace_default: str, - required: bool = False + model_id: str, + model_type: str, + workspace_default: str, + required: bool = False ) -> tuple: """Validate model ID, falling back to workspace default if invalid. @@ -275,7 +275,7 @@ class MemoryConfigService: logger.warning( f"{model_type} model validation failed, trying workspace default: {e}" ) - + # Fallback to workspace default if workspace_default: try: @@ -297,7 +297,7 @@ class MemoryConfigService: logger.error(f"Workspace default {model_type} model also invalid: {e}") if required: raise - + if required: raise InvalidConfigError( f"{model_type.title()} model is required but not configured", @@ -306,7 +306,7 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id ) - + return None, None # Step 2: Validate embedding model with workspace fallback @@ -343,6 +343,35 @@ class MemoryConfigService: if memory_config.rerank_id or workspace.rerank: logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") + vision_uuid, vision_name = validate_and_resolve_model_id( + memory_config.vision_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + audio_uuid, audio_name = validate_and_resolve_model_id( + memory_config.audio_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + video_uuid, video_name = validate_and_resolve_model_id( + memory_config.video_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) # Create immutable MemoryConfig object config = MemoryConfig( config_id=memory_config.config_id, @@ -356,6 +385,12 @@ class MemoryConfigService: embedding_model_name=embedding_name, rerank_model_id=rerank_uuid, rerank_model_name=rerank_name, + video_model_id=video_uuid, + video_model_name=video_name, + vision_model_id=vision_uuid, + vision_model_name=vision_name, + audio_model_id=audio_uuid, + audio_model_name=audio_name, storage_type=workspace.storage_type or "neo4j", chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker", reflexion_enabled=memory_config.enable_self_reflexion or False, @@ -364,24 +399,31 @@ class MemoryConfigService: reflexion_baseline=memory_config.baseline or "Time", loaded_at=datetime.now(), # Pipeline config: Deduplication - enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, - enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, + enable_llm_dedup_blockwise=bool( + memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, + enable_llm_disambiguation=bool( + memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True, t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8, t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8, t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8, # Pipeline config: Statement extraction - statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, - include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, - max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000, + statement_granularity=int( + memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, + include_dialogue_context=bool( + memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, + max_dialogue_context_chars=int( + memory_config.max_context) if memory_config.max_context is not None else 1000, # Pipeline config: Forgetting engine lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5, lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5, offset=float(memory_config.offset) if memory_config.offset is not None else 0.0, # Pipeline config: Pruning - pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, + pruning_enabled=bool( + memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, pruning_scene=memory_config.pruning_scene or "education", - pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, + pruning_threshold=float( + memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, # Ontology scene association scene_id=memory_config.scene_id, ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id), @@ -448,9 +490,9 @@ class MemoryConfigService: if not config: logger.warning(f"Model ID {model_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -481,9 +523,9 @@ class MemoryConfigService: if not config: logger.warning(f"Embedding model ID {embedding_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -571,25 +613,25 @@ class MemoryConfigService: """ from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.repositories.ontology_class_repository import OntologyClassRepository - + if not memory_config.scene_id: logger.debug("No scene_id configured, skipping ontology type fetch") return None - + try: ontology_repo = OntologyClassRepository(self.db) ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id) - + if not ontology_classes: logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}") return None - + ontology_types = OntologyTypeList.from_db_models(ontology_classes) logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" ) return ontology_types - + except Exception as e: logger.warning( f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}", @@ -598,8 +640,8 @@ class MemoryConfigService: return None def get_workspace_default_config( - self, - workspace_id: UUID + self, + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get workspace default memory config. @@ -613,19 +655,19 @@ class MemoryConfigService: Optional[MemoryConfigModel]: Default config or None if no configs exist """ config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id) - + if not config: logger.warning( "No active memory config found for workspace fallback", extra={"workspace_id": str(workspace_id)} ) - + return config def get_config_with_fallback( - self, - memory_config_id: Optional[UUID], - workspace_id: UUID + self, + memory_config_id: Optional[UUID], + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get memory config with fallback to workspace default. @@ -644,13 +686,13 @@ class MemoryConfigService: "No memory config ID provided, using workspace default", extra={"workspace_id": str(workspace_id)} ) - + config = MemoryConfigRepository.get_with_fallback( self.db, memory_config_id, workspace_id ) - + if not config and memory_config_id: logger.warning( "Memory config not found, falling back to workspace default", @@ -659,13 +701,13 @@ class MemoryConfigService: "workspace_id": str(workspace_id) } ) - + return config def delete_config( - self, - config_id: UUID | int, - force: bool = False + self, + config_id: UUID | int, + force: bool = False ) -> dict: """Delete memory config with protection against in-use configs. @@ -687,7 +729,7 @@ class MemoryConfigService: from app.core.exceptions import ResourceNotFoundException from app.models.memory_config_model import MemoryConfig as MemoryConfigModel from app.repositories.end_user_repository import EndUserRepository - + # 处理旧格式 int 类型的 config_id if isinstance(config_id, int): logger.warning( @@ -699,11 +741,11 @@ class MemoryConfigService: "message": "旧格式配置ID不支持删除操作,请使用新版配置", "legacy_int_id": config_id } - + config = self.db.get(MemoryConfigModel, config_id) if not config: raise ResourceNotFoundException("MemoryConfig", str(config_id)) - + # Check if this is the default config - default configs cannot be deleted if config.is_default: logger.warning( @@ -715,11 +757,11 @@ class MemoryConfigService: "message": "默认配置不允许删除", "is_default": True } - + # Use repository to count connected end users end_user_repo = EndUserRepository(self.db) connected_count = end_user_repo.count_by_memory_config_id(config_id) - + if connected_count > 0 and not force: logger.warning( "Attempted to delete memory config with connected end users", @@ -728,18 +770,18 @@ class MemoryConfigService: "connected_count": connected_count } ) - + return { "status": "warning", "message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置", "connected_count": connected_count, "force_required": True } - + # Force delete: use repository to clear end user references first if connected_count > 0 and force: cleared_count = end_user_repo.clear_memory_config_id(config_id) - + logger.warning( "Force deleting memory config, clearing end user references", extra={ @@ -747,11 +789,11 @@ class MemoryConfigService: "cleared_end_users": cleared_count } ) - + try: self.db.delete(config) self.db.commit() - + logger.info( "Memory config deleted", extra={ @@ -760,16 +802,16 @@ class MemoryConfigService: "affected_users": connected_count } ) - + return { "status": "success", "message": "记忆配置删除成功", "affected_users": connected_count } - + except IntegrityError as e: self.db.rollback() - + # Handle foreign key violation gracefully error_str = str(e.orig) if e.orig else str(e) if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower(): @@ -785,7 +827,7 @@ class MemoryConfigService: "message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除", "force_required": True } - + # Re-raise other integrity errors logger.error( "Delete failed due to integrity error", @@ -800,9 +842,9 @@ class MemoryConfigService: # ==================== 记忆配置提取方法 ==================== def extract_memory_config_id( - self, - app_type: str, - config: dict + self, + app_type: str, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从发布配置中提取 memory_config_id(根据应用类型分发) @@ -828,8 +870,8 @@ class MemoryConfigService: return None, False def _extract_memory_config_id_from_agent( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Agent 应用配置中提取 memory_config_id @@ -888,8 +930,8 @@ class MemoryConfigService: return None, False def _extract_memory_config_id_from_workflow( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Workflow 应用配置中提取 memory_config_id @@ -905,14 +947,14 @@ class MemoryConfigService: - is_legacy_int: 是否检测到旧格式 int 数据 """ nodes = config.get("nodes", []) - + for node in nodes: node_type = node.get("type", "") - + # 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite) if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]: config_id = node.get("config", {}).get("config_id") - + if config_id: try: # 处理字符串、UUID 和 int(旧数据兼容)三种情况 @@ -937,6 +979,6 @@ class MemoryConfigService: f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, " f"node_type={node_type}, error={str(e)}" ) - + logger.debug("工作流配置中未找到记忆节点") return None, False diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 8a7c86e2..d6c1de87 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -12,11 +12,12 @@ from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig -from app.models import FileMetadata +from app.models import FileMetadata, ModelApiKey, ModelType from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.prompt_optimizer_model import RoleType from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository -from app.schemas import FileType +from app.schemas import FileType, FileInput +from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_perceptual_schema import ( PerceptualQuerySchema, PerceptualTimelineResponse, @@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import ( AudioModal, Content, VideoModal, TextModal ) from app.schemas.model_schema import ModelInfo +from app.services.model_service import ModelApiKeyService +from app.services.multimodal_service import MultimodalService business_logger = get_business_logger() @@ -195,21 +198,58 @@ class MemoryPerceptualService: business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) + def _get_mutlimodal_client( + self, + file_type: FileType, + config: MemoryConfig + ) -> tuple[RedBearLLM | None, ModelApiKey | None]: + model_config = None + if file_type == FileType.AUDIO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.audio_model_id + ) + elif file_type == FileType.VIDEO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.video_model_id + ) + elif file_type == FileType.DOCUMENT: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.llm_model_id + ) + elif file_type == FileType.IMAGE: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.vision_model_id + ) + llm = None + if model_config: + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_config.model_name, + provider=model_config.provider, + api_key=model_config.api_key, + base_url=model_config.api_base, + is_omni=model_config.is_omni + ) + ) + return llm, model_config + async def generate_perceptual_memory( self, end_user_id: str, - model_config: ModelInfo, - file_type: str, - file_url: str, - file_message: dict, + memory_config: MemoryConfig, + file: FileInput ): - memories = self.repository.get_by_url(file_url) + memories = self.repository.get_by_url(file.url) if memories: - business_logger.info(f"Perceptual memory already exists: {file_url}") + business_logger.info(f"Perceptual memory already exists: {file.url}") if end_user_id not in [memory.end_user_id for memory in memories]: business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") memory_cache = memories[0] - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), perceptual_type=PerceptualType(memory_cache.perceptual_type), file_path=memory_cache.file_path, @@ -219,20 +259,31 @@ class MemoryPerceptualService: meta_data=memory_cache.meta_data ) self.db.commit() - - return - llm = RedBearLLM(RedBearModelConfig( + return memory + else: + for memory in memories: + if memory.end_user_id == uuid.UUID(end_user_id): + return memory + llm, model_config = self._get_mutlimodal_client(file.type, memory_config) + multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, provider=model_config.provider, api_key=model_config.api_key, - base_url=model_config.api_base, - is_omni=model_config.is_omni - ), type=model_config.model_type) + api_base=model_config.api_base, + is_omni=model_config.is_omni, + capability=model_config.capability, + model_type=ModelType.LLM + )) + file_message = await multimodel_service.process_files( + files=[file] + ) + if file_message: + file_message = file_message[0] try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() - rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') + rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) messages = [ @@ -242,8 +293,22 @@ class MemoryPerceptualService: ]} ] result = await llm.ainvoke(messages) - content = json_repair.repair_json(result.content, return_objects=True) - path = urlparse(file_url).path + content = result.content + final_output = "" + if isinstance(content, list): + for msg in content: + if isinstance(msg, dict): + final_output += msg.get("text", "") + elif isinstance(msg, str): + final_output += msg + elif isinstance(content, dict): + final_output += content.get("text", "") + elif isinstance(content, str): + final_output = content + else: + raise ValueError(f"Unexcept Model Output Type: {result.content}") + content = json_repair.repair_json(final_output, return_objects=True) + path = urlparse(file.url).path filename = os.path.basename(path) filename = unquote(filename) file_ext = os.path.splitext(filename)[1] @@ -260,13 +325,13 @@ class MemoryPerceptualService: except ValueError: business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: - if file_type == FileType.AUDIO: + if file.type == FileType.AUDIO: file_ext = ".mp3" - elif file_type == FileType.VIDEO: + elif file.type == FileType.VIDEO: file_ext = ".mp4" - elif file_type == FileType.DOCUMENT: + elif file.type == FileType.DOCUMENT: file_ext = ".txt" - elif file_type == FileType.IMAGE: + elif file.type == FileType.IMAGE: file_ext = ".jpg" filename += file_ext file_content = { @@ -274,11 +339,11 @@ class MemoryPerceptualService: "topic": content.get("topic"), "domain": content.get("domain") } - if file_type in [FileType.IMAGE, FileType.VIDEO]: + if file.type in [FileType.IMAGE, FileType.VIDEO]: file_modalities = { "scene": content.get("scene", []) } - elif file_type in [FileType.DOCUMENT]: + elif file.type in [FileType.DOCUMENT]: file_modalities = { "section_count": content.get("section_count", 0), "title": content.get("title", ""), @@ -288,10 +353,10 @@ class MemoryPerceptualService: file_modalities = { "speaker_count": content.get("speaker_count", 0) } - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType.trans_from_file_type(file_type), - file_path=file_url, + perceptual_type=PerceptualType.trans_from_file_type(file.type), + file_path=file.url, file_name=filename, file_ext=file_ext, summary=content.get('summary', ""), @@ -301,3 +366,4 @@ class MemoryPerceptualService: } ) self.db.commit() + return memory diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index f0c7cee2..eb8df242 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -9,14 +9,12 @@ - OpenAI: 支持 URL 和 base64 格式 """ import base64 +import csv import io -import uuid +import json from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional -import csv -import json - import PyPDF2 import httpx import magic @@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.schemas.model_schema import ModelInfo from app.services.audio_transcription_service import AudioTranscriptionService -from app.tasks import write_perceptual_memory logger = get_business_logger() @@ -342,15 +339,12 @@ class MultimodalService: async def process_files( self, - end_user_id: uuid.UUID | str, files: Optional[List[FileInput]], - ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 Args: - end_user_id: 用户ID files: 文件输入列表 Returns: @@ -358,8 +352,6 @@ class MultimodalService: """ if not files: return [] - if isinstance(end_user_id, uuid.UUID): - end_user_id = str(end_user_id) # 获取对应的策略 # dashscope 的 omni 模型使用 OpenAI 兼容格式 @@ -380,23 +372,15 @@ class MultimodalService: if file.type == FileType.IMAGE and "vision" in self.capability: is_support, content = await self._process_image(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.DOCUMENT: is_support, content = await self._process_document(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.AUDIO and "audio" in self.capability: is_support, content = await self._process_audio(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.VIDEO and "video" in self.capability: is_support, content = await self._process_video(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) else: logger.warning(f"不支持的文件类型: {file.type}") except Exception as e: @@ -418,17 +402,6 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - def write_perceptual_memory( - self, - end_user_id: str, - file_type: str, - file_url: str, - file_message: dict - ): - """写入感知记忆""" - if end_user_id and self.api_config: - write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message) - async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理图片文件 diff --git a/api/app/tasks.py b/api/app/tasks.py index c37e564e..8afb2194 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1080,12 +1080,14 @@ def write_message_task( config_id: str | int, storage_type: str, user_rag_memory_id: str, + file_messages: list[dict] | None, language: str = "zh" ) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write + file_messages: Files to write config_id: Configuration ID (can be UUID string, integer, or config_id_old) storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID @@ -1097,6 +1099,8 @@ def write_message_task( Raises: Exception on failure """ + if file_messages is None: + file_messages = [] logger.info( f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " @@ -1142,7 +1146,7 @@ def write_message_task( f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, + result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type, user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result From 6bba574ca64ff23570826b34e61246e6334918c0 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 13:54:15 +0800 Subject: [PATCH 5/8] feat(memory, model): update multi-modal memory write and model list API - Adjust multi-modal memory write behavior for text and visual data - Mask API keys in model list response to prevent exposure - Add capability-based filtering to the model list API --- api/app/controllers/model_controller.py | 12 ++ .../controllers/user_memory_controllers.py | 108 ++++++++------ .../core/memory/agent/utils/get_dialogs.py | 5 +- .../core/memory/agent/utils/write_tools.py | 40 +---- .../core/memory/llm_tools/chunker_client.py | 10 +- api/app/core/memory/models/graph_models.py | 33 +++-- api/app/core/memory/models/message_models.py | 4 +- .../extraction_orchestrator.py | 80 +++++++++- api/app/core/workflow/nodes/memory/node.py | 9 +- api/app/models/models_model.py | 5 +- .../repositories/memory_config_repository.py | 3 + api/app/repositories/model_repository.py | 12 +- api/app/repositories/neo4j/add_nodes.py | 139 +++--------------- api/app/repositories/neo4j/cypher_queries.py | 109 +++++++------- api/app/repositories/neo4j/graph_saver.py | 72 ++++++--- api/app/schemas/model_schema.py | 38 +++-- api/app/services/memory_agent_service.py | 39 ++--- api/app/services/memory_perceptual_service.py | 8 +- api/app/services/pilot_run_service.py | 3 + api/app/services/user_memory_service.py | 3 +- api/app/tasks.py | 58 +------- 21 files changed, 389 insertions(+), 401 deletions(-) diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 6204a745..71fd41ad 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -42,6 +42,7 @@ def get_model_strategies(): @router.get("", response_model=ApiResponse) def get_model_list( type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"), @@ -74,10 +75,21 @@ def get_model_list( unique_flat_type = list(dict.fromkeys(flat_type)) type_list = [ModelType(t.lower()) for t in unique_flat_type] + capability_list = [] + if capability is not None: + flat_capability = [] + for item in capability: + split_items = [c.strip() for c in item.split(', ') if c.strip()] + flat_capability.extend(split_items) + + unique_flat_capability = list(dict.fromkeys(flat_capability)) + capability_list = unique_flat_capability + api_logger.error(f"获取模型type_list: {type_list}") query = model_schema.ModelConfigQuery( type=type_list, provider=provider, + capability=capability_list, is_active=is_active, is_public=is_public, search=search, diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index be796ff9..3ce1df6e 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -5,7 +5,7 @@ from typing import Optional import datetime from sqlalchemy.orm import Session -from fastapi import APIRouter, Depends,Header +from fastapi import APIRouter, Depends, Header from app.db import get_db from app.core.language_utils import get_language_from_header @@ -19,7 +19,7 @@ from app.services.user_memory_service import ( analytics_graph_data, analytics_community_graph_data, ) -from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction +from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository @@ -45,9 +45,9 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的记忆洞察报告 @@ -73,10 +73,10 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( - end_user_id: str, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的用户摘要 @@ -90,7 +90,7 @@ async def get_user_summary_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -102,7 +102,7 @@ async def get_user_summary_api( api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language) + result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language) if result["is_cached"]: api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") @@ -117,10 +117,10 @@ async def get_user_summary_api( @router.post("/analytics/generate_cache", response_model=ApiResponse) async def generate_cache_api( - request: GenerateCacheRequest, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + request: GenerateCacheRequest, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 手动触发缓存生成 @@ -134,7 +134,7 @@ async def generate_cache_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -155,10 +155,12 @@ async def generate_cache_api( api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") # 生成记忆洞察 - insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language) + insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, + language=language) # 生成用户摘要 - summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language) + summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, + language=language) # 构建响应 result = { @@ -209,9 +211,9 @@ async def generate_cache_api( @router.get("/analytics/node_statistics", response_model=ApiResponse) async def get_node_statistics_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -220,7 +222,8 @@ async def get_node_statistics_api( api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") + api_logger.info( + f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") try: # 调用新的记忆类型统计函数 @@ -228,21 +231,23 @@ async def get_node_statistics_api( # 计算总数用于日志 total_count = sum(item["count"] for item in result) - api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") + api_logger.info( + f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) + @router.get("/analytics/graph_data", response_model=ApiResponse) async def get_graph_data_api( - end_user_id: str, - node_types: Optional[str] = None, - limit: int = 100, - depth: int = 1, - center_node_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + node_types: Optional[str] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -298,9 +303,9 @@ async def get_graph_data_api( @router.get("/analytics/community_graph", response_model=ApiResponse) async def get_community_graph_data_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -334,9 +339,9 @@ async def get_community_graph_data_api( @router.get("/read_end_user/profile", response_model=ApiResponse) async def get_end_user_profile( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) @@ -385,9 +390,9 @@ async def get_end_user_profile( @router.post("/updated_end_user/profile", response_model=ApiResponse) async def update_end_user_profile( - profile_update: EndUserProfileUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + profile_update: EndUserProfileUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 更新终端用户的基本信息 @@ -417,7 +422,7 @@ async def update_end_user_profile( else: error_msg = result["error"] api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") - + # 根据错误类型映射到合适的业务错误码 if error_msg == "终端用户不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) @@ -427,15 +432,18 @@ async def update_end_user_profile( # 只有未预期的错误才使用 INTERNAL_ERROR return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) + @router.get("/memory_space/timeline_memories", response_model=ApiResponse) -async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): +async def memory_space_timeline_of_shared_memories( + id: str, label: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): # 使用集中化的语言校验 language = get_language_from_header(language_type) - - workspace_id=current_user.current_workspace_id + + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_ timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language) return success(data=timeline_memories_result, msg="共同记忆时间线") + + @router.get("/memory_space/relationship_evolution", response_model=ApiResponse) async def memory_space_relationship_evolution(id: str, label: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ): try: api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 3b06defe..4c667061 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -11,7 +11,7 @@ async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", end_user_id: str = "group_1", messages: list = None, - ref_id: str = "wyl_20251027", + ref_id: str = "", config_id: str = None ) -> List[DialogData]: """Generate chunks from structured messages using the specified chunker strategy. @@ -40,12 +40,13 @@ async def get_chunked_dialogs( role = msg['role'] content = msg['content'] + files = msg.get("file_content", []) if role not in ['user', 'assistant']: raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") if content.strip(): - conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) + conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files)) if not conversation_messages: raise ValueError("Message list cannot be empty after filtering") diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 147a0316..413f54da 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -5,8 +5,8 @@ This module provides the main write function for executing the knowledge extract pipeline. Only MemoryConfig is needed - clients are constructed internally. """ import asyncio -import uuid import time +import uuid from datetime import datetime from dotenv import load_dotenv @@ -19,10 +19,8 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context -from app.models import MemoryPerceptualModel from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges -from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \ - add_perceptual_dialogue_edges +from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -36,7 +34,6 @@ async def write( end_user_id: str, memory_config: MemoryConfig, messages: list, - file_content: list[MemoryPerceptualModel], ref_id: str = "", language: str = "zh", ) -> None: @@ -47,7 +44,6 @@ async def write( end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] - file_content: mutilmodal message list ref_id: Reference ID, defaults to "" language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ @@ -142,9 +138,11 @@ async def write( all_chunk_nodes, all_statement_nodes, all_entity_nodes, + all_perceptual_nodes, all_statement_chunk_edges, all_statement_entity_edges, all_entity_entity_edges, + all_perceptual_edges, all_dedup_details, ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) @@ -169,9 +167,11 @@ async def write( chunk_nodes=all_chunk_nodes, statement_nodes=all_statement_nodes, entity_nodes=all_entity_nodes, + perceptual_nodes=all_perceptual_nodes, statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, + perceptual_edges=all_perceptual_edges, connector=neo4j_connector, ) if success: @@ -230,34 +230,6 @@ async def write( finally: log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file) - # Step 5: Save perceptual memory to Neo4j - step_start = time.time() - if file_content: - try: - pc_connector = Neo4jConnector() - try: - created_ids = await add_perceptual_nodes( - perceptuals=file_content, - connector=pc_connector, - embedder_client=embedder_client, - ) - # 如果有 ref_id,建立感知记忆与对话的关联 - if ref_id and created_ids: - await add_perceptual_dialogue_edges( - perceptuals=file_content, - dialog_id=ref_id, - connector=pc_connector, - ) - logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j") - finally: - try: - await pc_connector.close() - except Exception: - pass - except Exception as e: - logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True) - log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file) - # Log total pipeline time total_time = time.time() - pipeline_start log_time("TOTAL PIPELINE TIME", total_time, log_file) diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 93a2df82..51d15aab 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -1,10 +1,10 @@ -from typing import Any, List -import re -import os import asyncio import json -import numpy as np import logging +import os +from typing import Any, List + +import numpy as np # Fix tokenizer parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -246,6 +246,7 @@ class ChunkerClient: "total_sub_chunks": len(sub_chunks), "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) else: @@ -258,6 +259,7 @@ class ChunkerClient: "message_role": msg.role, "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index fb251f1f..1b8c9d52 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -114,7 +114,7 @@ class Edge(BaseModel): end_user_id: str = Field(..., description="The end user ID of the edge.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") - expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") + expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.") class ChunkEdge(Edge): @@ -175,6 +175,12 @@ class EntityEntityEdge(Edge): return parse_historical_datetime(v) +class PerceptualEdge(Edge): + """Edge connecting perceptual nodes to their source chunks + """ + pass + + class Node(BaseModel): """Base class for all graph nodes in the knowledge graph. @@ -555,19 +561,16 @@ class MemorySummaryNode(Node): ) -class MutlimodalNode(Node): +class PerceptualNode(Node): """Node representing a multimodal message in the knowledge graph. - - Attributes: - dialog_id: ID of the parent dialog - message_id: ID of the message - metadata: Additional message metadata - embedding: Optional embedding vector for the message """ - dialog_id: str = Field(..., description="ID of the parent dialog") - message_id: str = Field(..., description="ID of the message") - summary: str = Field(..., description="The text content of the message") - file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')") - file_path: List[str] = Field(..., description="List of file paths for multimodal content") - metadata: dict = Field(default_factory=dict, description="Additional message metadata") - embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message") + perceptual_type: int + file_path: str + file_name: str + file_ext: str + summary: str + keywords: list[str] + topic: str + domain: str + file_type: str + summary_embedding: list[float] | None diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 2f8660af..66203067 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -30,6 +30,7 @@ class ConversationMessage(BaseModel): """ role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") msg: str = Field(..., description="The text content of the message.") + files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True) class TemporalValidityRange(BaseModel): @@ -130,7 +131,8 @@ class Chunk(BaseModel): content: str = Field(..., description="The content of the chunk as a string.") speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") - chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") + files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.") + chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") @classmethod diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 6e94a84f..da10c497 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -31,7 +31,9 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, StatementChunkEdge, StatementEntityEdge, - StatementNode + StatementNode, + PerceptualEdge, + PerceptualNode ) from app.core.memory.models.message_models import DialogData from app.core.memory.models.ontology_extraction_models import OntologyTypeList @@ -170,9 +172,11 @@ class ExtractionOrchestrator: list[ChunkNode], list[StatementNode], list[ExtractedEntityNode], + list[PerceptualNode], list[StatementChunkEdge], list[StatementEntityEdge], list[EntityEntityEdge], + list[PerceptualEdge], dict ]: """ @@ -259,9 +263,11 @@ class ExtractionOrchestrator: chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) = await self._create_nodes_and_edges(dialog_data_list) # 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总) @@ -275,7 +281,16 @@ class ExtractionOrchestrator: # 注意:deduplication 消息已在创建节点和边完成后立即发送 - result = await self._run_dedup_and_write_summary( + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + dialog_data_list, + ) = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, statement_nodes, @@ -287,7 +302,18 @@ class ExtractionOrchestrator: ) logger.info(f"知识提取流水线运行完成({mode_str})") - return result + return ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + perceptual_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + perceptual_edges, + dialog_data_list, + ) except Exception as e: logger.error(f"知识提取流水线运行失败: {e}", exc_info=True) @@ -1000,9 +1026,11 @@ class ExtractionOrchestrator: List[ChunkNode], List[StatementNode], List[ExtractedEntityNode], + List[PerceptualNode], List[StatementChunkEdge], List[StatementEntityEdge], - List[EntityEntityEdge] + List[EntityEntityEdge], + List[PerceptualEdge] ]: """ 创建图数据库节点和边 @@ -1026,6 +1054,8 @@ class ExtractionOrchestrator: statement_chunk_edges = [] statement_entity_edges = [] entity_entity_edges = [] + perceptual_nodes = [] + perceptual_edges = [] # 用于去重的集合 entity_id_set = set() @@ -1069,6 +1099,46 @@ class ExtractionOrchestrator: metadata=chunk.metadata, ) chunk_nodes.append(chunk_node) + logger.error(f"chunk file: {chunk.files}") + + for p, file_type in chunk.files: + + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + # 生成 summary embedding(如果有 embedder_client) + summary_embedding = None + if self.embedder_client and p.summary: + try: + summary_embedding = (await self.embedder_client.response([p.summary]))[0] + except Exception as emb_err: + print(f"Failed to embed perceptual summary: {emb_err}") + + perceptual = PerceptualNode( + name=f"Perceptual_{p.id}", + **{ + "id": str(p.id), + "end_user_id": str(p.end_user_id), + "perceptual_type": p.perceptual_type, + "file_path": p.file_path or "", + "file_name": p.file_name or "", + "file_ext": p.file_ext or "", + "summary": p.summary or "", + "keywords": content_meta.get("keywords", []), + "topic": content_meta.get("topic", ""), + "domain": content_meta.get("domain", ""), + "created_at": p.created_time.isoformat() if p.created_time else None, + "file_type": file_type, + "summary_embedding": summary_embedding, + }) + perceptual_nodes.append(perceptual) + perceptual_edges.append(PerceptualEdge( + source=perceptual.id, + target=chunk.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + )) # 处理每个陈述句 for statement in chunk.statements: @@ -1248,9 +1318,11 @@ class ExtractionOrchestrator: chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) async def _run_dedup_and_write_summary( diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index cbdad0fa..a28247e4 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -72,7 +72,6 @@ class MemoryWriteNode(BaseNode): if not end_user_id: raise RuntimeError("End user id is required") messages = [] - multimodal_memories = [] if self.typed_config.message: messages.append({ "role": "user", @@ -104,19 +103,15 @@ class MemoryWriteNode(BaseNode): url=file_instence.value.url, file_type=file_instence.value.origin_file_type ).model_dump()) - multimodal_memories.append({ - "role": message.role, - "files": file_info - }) messages.append({ "role": message.role, - "content": self._render_template(content, variable_pool) + "content": self._render_template(content, variable_pool), + "files": file_info }) write_message_task.delay( end_user_id=end_user_id, message=messages, - file_messages=multimodal_memories, config_id=str(self.typed_config.config_id), storage_type=state["memory_storage_type"], user_rag_memory_id=state["user_rag_memory_id"] diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 23fafcef..44a844d0 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -2,10 +2,11 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text -from sqlalchemy.dialects.postgresql import UUID, JSON +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text +from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY from sqlalchemy.orm import relationship from sqlalchemy.sql import func + from app.db import Base diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 6fb41914..e64d19a3 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -408,6 +408,9 @@ class MemoryConfigRepository: "llm_id": db_config.llm_id, "embedding_id": db_config.embedding_id, "rerank_id": db_config.rerank_id, + "vision_id": db_config.vision_id, + "audio_id": db_config.audio_id, + "video_id": db_config.video_id, "enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise, "enable_llm_disambiguation": db_config.enable_llm_disambiguation, "deep_retrieval": db_config.deep_retrieval, diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index f49227d3..fd95c793 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -1,14 +1,15 @@ -from sqlalchemy.orm import Session, joinedload, selectinload -from sqlalchemy import and_, or_, func, desc, select -from typing import List, Optional, Dict, Any, Tuple import uuid +from typing import List, Optional, Dict, Any, Tuple +from sqlalchemy import and_, or_, func, desc +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association from app.schemas.model_schema import ( ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigQuery, ModelConfigQueryNew ) -from app.core.logging_config import get_db_logger # 获取数据库专用日志器 db_logger = get_db_logger() @@ -137,6 +138,9 @@ class ModelConfigRepository: type_values.append(ModelType.LLM) filters.append(ModelConfig.type.in_(type_values)) + if query.capability: + filters.append(ModelConfig.capability.contains(query.capability)) + if query.is_active is not None: filters.append(ModelConfig.is_active == query.is_active) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 3a017089..a53ca289 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,16 +1,19 @@ from typing import List, Optional +from app.core.logging_config import get_logger from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ - MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE + MEMORY_SUMMARY_NODE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector +logger = get_logger(__name__) + async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") - print(f"All end_user_id: {end_user_id} node and edge deleted successfully") + logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result @@ -25,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn List of created node UUIDs or None if failed """ if not dialogues: - print("No dialogues to save") + logger.info("No dialogues to save") return [] try: @@ -50,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") + logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") return created_uuids except Exception as e: - print(f"Error creating dialogue nodes: {e}") + logger.info(f"Error creating dialogue nodes: {e}") return None @@ -69,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC List of created node UUIDs or None if failed """ if not statements: - print("No statements to save") + logger.info("No statements to save") return [] try: @@ -122,11 +125,11 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} statement nodes") + logger.info(f"Successfully created {len(created_uuids)} statement nodes") return created_uuids except Exception as e: - print(f"Error creating statement nodes: {e}") + logger.info(f"Error creating statement nodes: {e}") return None @@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> List of created chunk UUIDs or None if failed """ if not chunks: - print("No chunk nodes to add") + logger.info("No chunk nodes to add") return [] try: @@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} chunk nodes") + logger.info(f"Successfully created {len(created_uuids)} chunk nodes") return created_uuids except Exception as e: - print(f"Error creating chunk nodes: {e}") + logger.info(f"Error creating chunk nodes: {e}") return None -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[ - List[str]]: +async def add_memory_summary_nodes( + summaries: List[MemorySummaryNode], + connector: Neo4jConnector +) -> Optional[List[str]]: """Add memory summary nodes to Neo4j in batch. Args: @@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector List of created summary node ids or None if failed """ if not summaries: - print("No memory summary nodes to add") + logger.info("No memory summary nodes to add") return [] try: @@ -220,110 +225,8 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") + logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids except Exception as e: - print(f"Failed to save MemorySummary nodes to Neo4j: {e}") - return None - - -async def add_perceptual_nodes( - perceptuals: list, - connector: Neo4jConnector, - embedder_client=None, -) -> Optional[List[str]]: - """Add perceptual memory nodes to Neo4j in batch. - - Args: - perceptuals: List of MemoryPerceptualModel objects from PostgreSQL - connector: Neo4j connector instance - embedder_client: Optional embedder client for generating summary embeddings - - Returns: - List of created node UUIDs or None if failed - """ - if not perceptuals: - print("No perceptual nodes to add") - return [] - - try: - flattened = [] - for p in perceptuals: - meta = p.meta_data or {} - content_meta = meta.get("content", {}) - - # 生成 summary embedding(如果有 embedder_client) - summary_embedding = None - if embedder_client and p.summary: - try: - summary_embedding = (await embedder_client.response([p.summary]))[0] - except Exception as emb_err: - print(f"Failed to embed perceptual summary: {emb_err}") - - flattened.append({ - "id": str(p.id), - "end_user_id": str(p.end_user_id), - "perceptual_type": p.perceptual_type, - "file_path": p.file_path or "", - "file_name": p.file_name or "", - "file_ext": p.file_ext or "", - "summary": p.summary or "", - "keywords": content_meta.get("keywords", []), - "topic": content_meta.get("topic", ""), - "domain": content_meta.get("domain", ""), - "created_at": p.created_time.isoformat() if p.created_time else None, - "summary_embedding": summary_embedding, - }) - - result = await connector.execute_query( - PERCEPTUAL_NODE_SAVE, - perceptuals=flattened, - ) - created_uuids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j") - return created_uuids - - except Exception as e: - print(f"Failed to save Perceptual nodes to Neo4j: {e}") - return None - - -async def add_perceptual_dialogue_edges( - perceptuals: list, - dialog_id: str, - connector: Neo4jConnector, -) -> Optional[List[str]]: - """Add edges between Perceptual nodes and Dialogue nodes. - - Args: - perceptuals: List of MemoryPerceptualModel objects - dialog_id: The dialogue ID (or ref_id) to link to - connector: Neo4j connector instance - - Returns: - List of created edge element IDs or None if failed - """ - if not perceptuals or not dialog_id: - return [] - - try: - edges = [] - for p in perceptuals: - edges.append({ - "perceptual_id": str(p.id), - "dialog_id": dialog_id, - "end_user_id": str(p.end_user_id), - "created_at": p.created_time.isoformat() if p.created_time else None, - }) - - result = await connector.execute_query( - PERCEPTUAL_DIALOGUE_EDGE_SAVE, - edges=edges, - ) - created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j") - return created_ids - - except Exception as e: - print(f"Failed to save Perceptual-Dialogue edges: {e}") + logger.info(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 49dbe2a5..d70a30e9 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1003,60 +1003,69 @@ RETURN DISTINCT """ Graph_Node_query = """ - MATCH (n:MemorySummary) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 0 AS priority - LIMIT $limit +MATCH (n:MemorySummary) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Dialogue) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT 1 +MATCH (n:Dialogue) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT 1 - UNION ALL +UNION ALL - MATCH (n:Statement) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT $limit +MATCH (n:Statement) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:ExtractedEntity) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 2 AS priority - LIMIT $limit +MATCH (n:ExtractedEntity) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 2 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Chunk) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 3 AS priority - LIMIT $limit +MATCH (n:Chunk) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 3 AS priority +LIMIT $limit - """ +UNION ALL +MATCH (n:Perceptual) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 4 AS priority + +""" # ============================================================ # Community 节点 & BELONGS_TO_COMMUNITY 边 @@ -1340,19 +1349,19 @@ SET n += { topic: p.topic, domain: p.domain, created_at: p.created_at, + file_type: p.file_type, summary_embedding: p.summary_embedding } RETURN n.id AS uuid """ # 感知记忆与对话的关联边 -PERCEPTUAL_DIALOGUE_EDGE_SAVE = """ +PERCEPTUAL_CHUNK_EDGE_SAVE = """ UNWIND $edges AS edge MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id}) -MATCH (d:Dialogue {end_user_id: edge.end_user_id}) -WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id -MERGE (d)-[r:HAS_PERCEPTUAL]->(p) -SET r.end_user_id = edge.end_user_id, +MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id}) +MERGE (c)-[r:HAS_PERCEPTUAL]->(p) +ON CREATE SET r.end_user_id = edge.end_user_id, r.created_at = edge.created_at RETURN elementId(r) AS uuid """ diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 34497d5b..d78dcef6 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import ( StatementNode, ExtractedEntityNode, EntityEntityEdge, + PerceptualNode, + PerceptualEdge, ) import logging + logger = logging.getLogger(__name__) + + async def save_entities_and_relationships( - entity_nodes: List[ExtractedEntityNode], - entity_entity_edges: List[EntityEntityEdge], - connector: Neo4jConnector + entity_nodes: List[ExtractedEntityNode], + entity_entity_edges: List[EntityEntityEdge], + connector: Neo4jConnector ): """Save entities and their relationships using graph models""" all_entities = [entity.model_dump() for entity in entity_nodes] @@ -73,8 +78,8 @@ async def save_entities_and_relationships( async def save_chunk_nodes( - chunk_nodes: List[ChunkNode], - connector: Neo4jConnector + chunk_nodes: List[ChunkNode], + connector: Neo4jConnector ): """Save chunk nodes using graph models""" if not chunk_nodes: @@ -89,8 +94,8 @@ async def save_chunk_nodes( async def save_statement_chunk_edges( - statement_chunk_edges: List[StatementChunkEdge], - connector: Neo4jConnector + statement_chunk_edges: List[StatementChunkEdge], + connector: Neo4jConnector ): """Save statement-chunk edges using graph models""" if not statement_chunk_edges: @@ -118,8 +123,8 @@ async def save_statement_chunk_edges( async def save_statement_entity_edges( - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + statement_entity_edges: List[StatementEntityEdge], + connector: Neo4jConnector ): """Save statement-entity edges using graph models""" if not statement_entity_edges: @@ -142,7 +147,7 @@ async def save_statement_entity_edges( if all_se_edges: try: await connector.execute_query( - STATEMENT_ENTITY_EDGE_SAVE, + STATEMENT_ENTITY_EDGE_SAVE, relationships=all_se_edges ) except Exception: @@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j( chunk_nodes: List[ChunkNode], statement_nodes: List[StatementNode], entity_nodes: List[ExtractedEntityNode], + perceptual_nodes: List[PerceptualNode], entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], + perceptual_edges: List[PerceptualEdge], connector: Neo4jConnector, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j( chunk_nodes: List of ChunkNode objects to save statement_nodes: List of StatementNode objects to save entity_nodes: List of ExtractedEntityNode objects to save + perceptual_nodes: List of PerceptualNode objects to save entity_edges: List of EntityEntityEdge objects to save statement_chunk_edges: List of StatementChunkEdge objects to save statement_entity_edges: List of StatementEntityEdge objects to save + perceptual_edges: List of PerceptualEdge objects to save connector: Neo4j connector instance Returns: @@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) dialogue_uuids = [record["uuid"] async for record in result] results['dialogues'] = dialogue_uuids - print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") + logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") # 2. Save all chunk nodes in batch if chunk_nodes: @@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j( results['chunks'] = chunk_uuids logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") + if perceptual_nodes: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE + perceptual_data = [node.model_dump() for node in perceptual_nodes] + result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data) + perceptual_uuids = [record["uuid"] async for record in result] + results["perceptuals"] = perceptual_uuids + logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j") + # 3. Save all statement nodes in batch if statement_nodes: from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE @@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j( results['statement_entity_edges'] = se_uuids logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + if perceptual_edges: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE + perceptual_edge_data = [] + for edge in perceptual_edges: + print(edge.source, edge.target) + perceptual_edge_data.append({ + "perceptual_id": edge.source, + "chunk_id": edge.target, + "end_user_id": edge.end_user_id, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + }) + result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data) + perceptual_edges_uuids = [record["uuid"] async for record in result] + results['perceptual_chunk_edges'] = perceptual_edges_uuids + logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j") + return results try: @@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j( def schedule_clustering_after_write( - entity_nodes: List, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, + entity_nodes: List, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 写入 Neo4j 成功后,调度后台聚类任务。 @@ -325,14 +358,15 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id)) async def _trigger_clustering( - new_entity_ids: List[str], - end_user_id: str, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, + new_entity_ids: List[str], + end_user_id: str, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 058f082d..668a84a8 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase): updated_at: datetime.datetime api_keys: List["ModelApiKey"] = [] + @staticmethod + def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str: + if not key or len(key) <= prefix + suffix: + return "*" * len(key) + return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:] + @field_validator("api_keys", mode="after") @classmethod def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]: @@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase): def _serialize_created_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None + @field_serializer("api_keys", when_used="json") + def _serialize_api_keys(self, api_keys: List["ModelApiKey"]): + result = [] + for api_key in api_keys: + data = api_key.model_dump() + data["api_key"] = self.mask_api_key(api_key.api_key) + result.append(data) + return result + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase): if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict): self.model_config_ids = [ mc.id for mc in self.model_configs - if hasattr(mc, 'id') - and not getattr(mc, 'is_composite', False) - and getattr(mc, 'name', None) == self.model_name + if hasattr(mc, 'id') + and not getattr(mc, 'is_composite', False) + and getattr(mc, 'name', None) == self.model_name ] # 情况2:字典列表 elif isinstance(self.model_configs, list): self.model_config_ids = [ mc['id'] if isinstance(mc, dict) else mc.id for mc in self.model_configs - if ((isinstance(mc, dict) - and 'id' in mc + if ((isinstance(mc, dict) + and 'id' in mc and not mc.get('is_composite', False) - and mc.get('name') == self.model_name) or - (hasattr(mc, 'id') + and mc.get('name') == self.model_name) or + (hasattr(mc, 'id') and not getattr(mc, 'is_composite', False) and getattr(mc, 'name', None) == self.model_name)) ] @@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase): validate_assignment=True # 确保赋值触发校验 ) - @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel): """模型配置查询Schema""" type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") + capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)") is_active: Optional[bool] = Field(None, description="激活状态筛选") is_public: Optional[bool] = Field(None, description="公开状态筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) @@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel): is_composite: Optional[bool] = Field(None, description="组合模型筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelMarketplace(BaseModel): """模型广场响应Schema""" llm_models: List[ModelConfig] = [] @@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel): class ModelBase(BaseModel): """基础模型Schema""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID name: str type: str @@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel): is_deprecated: Optional[bool] = Field(None, description="是否弃用") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelInfo(BaseModel): """模型信息Schema""" model_name: str = Field(..., description="模型名称") @@ -336,4 +353,3 @@ class ModelInfo(BaseModel): is_omni: bool = Field(default=False, description="是否为omni模型") model_type: ModelType = Field(..., description="模型类型") capability: List[str] = Field(default_factory=list, description="模型能力列表") - diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 875f02bb..8bb6538d 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -274,7 +274,6 @@ class MemoryAgentService: self, end_user_id: str, messages: list[dict], - file_messages: list[dict], config_id: Optional[uuid.UUID] | int, db: Session, storage_type: str, @@ -287,7 +286,6 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) messages: Message to write - files: Files to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -348,15 +346,15 @@ class MemoryAgentService: raise ValueError(error_msg) perceptual_serivce = MemoryPerceptualService(db) - file_content = [] - for message in file_messages: + for message in messages: + message["file_content"] = [] for file in message["files"]: file_object = await perceptual_serivce.generate_perceptual_memory( end_user_id=end_user_id, memory_config=memory_config, file=FileInput(**file) ) - file_content.append(file_object) + message["file_content"].append((file_object, file["type"])) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: @@ -368,7 +366,6 @@ class MemoryAgentService: await write_neo4j( end_user_id=end_user_id, messages=messages, - file_content=file_content, memory_config=memory_config, ref_id='', language=language @@ -380,19 +377,23 @@ class MemoryAgentService: if deleted: logger.info( f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") - return self.writer_messages_deal( - "success", - start_time, - end_user_id, - config_id, - message_text, - { - "status": "success", - "data": messages, - "config_id": memory_config.config_id, - "config_name": memory_config.config_name - } - ) + for message in messages: + message["file_content"] = [ + perceptual[0].file_path for perceptual in message["file_content"] + ] + return self.writer_messages_deal( + "success", + start_time, + end_user_id, + config_id, + message_text, + { + "status": "success", + "data": messages, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name + } + ) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index d6c1de87..8255dbbe 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -317,11 +317,11 @@ class MemoryPerceptualService: stmt = select(FileMetadata).where( FileMetadata.id == file_id ) - file = self.db.execute(stmt).scalar_one_or_none() + file_obj = self.db.execute(stmt).scalar_one_or_none() - if file: - filename = file.file_name - file_ext = file.file_ext + if file_obj: + filename = file_obj.file_name + file_ext = file_obj.file_ext except ValueError: business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index fc749157..4617946b 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -297,9 +297,12 @@ async def run_pilot_extraction( chunk_nodes, statement_nodes, entity_nodes, + _, statement_chunk_edges, statement_entity_edges, entity_edges, + _, + _ ) = extraction_result log_time("Extraction Pipeline", time.time() - step_start, log_file) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index d5d19e0d..29516acc 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1888,7 +1888,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ "Chunk": ["content", "created_at"], "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"], "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"], - "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 + "MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段 + "Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"] } # 获取该节点类型的白名单字段 diff --git a/api/app/tasks.py b/api/app/tasks.py index 8afb2194..f243eac3 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1080,14 +1080,12 @@ def write_message_task( config_id: str | int, storage_type: str, user_rag_memory_id: str, - file_messages: list[dict] | None, language: str = "zh" ) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write - file_messages: Files to write config_id: Configuration ID (can be UUID string, integer, or config_id_old) storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID @@ -1099,9 +1097,6 @@ def write_message_task( Raises: Exception on failure """ - if file_messages is None: - file_messages = [] - logger.info( f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " f"config_id={config_id} (type: {type(config_id).__name__}), " @@ -1146,7 +1141,7 @@ def write_message_task( f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type, + result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result @@ -2617,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ } -@celery_app.task( - name="app.tasks.write_perceptual_memory", - bind=True, - ignore_result=True, - max_retries=0, - acks_late=False, - time_limit=3600, - soft_time_limit=3300, -) -def write_perceptual_memory( - self, - end_user_id: str, - model_api_config: dict, - file_type: str, - file_url: str, - file_message: dict -): - """ - Write perceptual memory for a user into PostgreSQL and Neo4j. - - This task generates or updates the user's perceptual memory - in the backend databases. It is intended to be executed asynchronously - via Celery. - - Args: - end_user_id (uuid.UUID): The unique identifier of the end user. - model_api_config (ModelInfo): API configuration for the model - used to generate perceptual memory. - file_type (str): The file type - file_url (url): The url of file - file_message (dict): The file message containing details about the file - to be processed. - - Returns: - None - """ - file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest() - set_asyncio_event_loop() - with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()): - model_info = ModelInfo(**model_api_config) - with get_db_context() as db: - memory_perceptual_service = MemoryPerceptualService(db) - return asyncio.run(memory_perceptual_service.generate_perceptual_memory( - end_user_id, - model_info, - file_type, - file_url, - file_message, - )) - - # ============================================================================= # 社区聚类补全任务(触发型) # ============================================================================= From de6e2f54d2d773969c36a3658791668171034fc4 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 14:30:07 +0800 Subject: [PATCH 6/8] fix(perceptual): prevent errors when writing unsupported modalities --- api/app/services/memory_agent_service.py | 2 ++ api/app/services/memory_api_service.py | 4 ---- api/app/services/memory_perceptual_service.py | 6 ++++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 8bb6538d..3cfcc1d6 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -354,6 +354,8 @@ class MemoryAgentService: memory_config=memory_config, file=FileInput(**file) ) + if file_object is None: + continue message["file_content"].append((file_object, file["type"])) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 9a0fb8ed..9282fc28 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -131,7 +131,6 @@ class MemoryAPIService: message: str, config_id: str, storage_type: str = "neo4j", - files: Optional[list]=None, user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Write memory with validation. @@ -154,8 +153,6 @@ class MemoryAPIService: ResourceNotFoundException: If end_user not found BusinessException: If end_user not in authorized workspace or write fails """ - if files is None: - files = list() logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") # Validate end_user exists and belongs to workspace @@ -175,7 +172,6 @@ class MemoryAPIService: db=self.db, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id or "", - files=files ) logger.info(f"Memory write successful for end_user: {end_user_id}") diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 8255dbbe..effceda7 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -277,8 +277,10 @@ class MemoryPerceptualService: file_message = await multimodel_service.process_files( files=[file] ) - if file_message: - file_message = file_message[0] + if not file_message: + logger.warning(f"Unsupport file type {file}, model capability: {model_config.capability}") + return None + file_message = file_message[0] try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: From 3dc863cabf29d08c02612253f3569a18fe6ffadc Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 15:16:16 +0800 Subject: [PATCH 7/8] feat(memory): add audio_id, vision_id and video_id fields to memory configuration --- api/app/schemas/memory_storage_schema.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index c8abbc46..711b6de9 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -264,6 +264,9 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 config_id: Union[uuid.UUID, int, str] = None llm_id: Optional[str] = Field(None, description="LLM模型配置ID") + audio_id: Optional[str] = Field(None, description="语音模型ID") + vision_id: Optional[str] = Field(None, description="视觉模型ID") + video_id: Optional[str] = Field(None, description="视频模型ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") enable_llm_dedup_blockwise: Optional[bool] = None From b739b032d940ac77a2d8b269bdc77beb956f4e14 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 15:17:01 +0800 Subject: [PATCH 8/8] fix(workflow): remove edges for unreachable nodes in graph --- api/app/core/workflow/engine/graph_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 674c45d0..c5cf3324 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -390,6 +390,8 @@ class GraphBuilder: for edge in self.edges: source = edge.get("source") target = edge.get("target") + if source not in self.reachable_nodes or target not in self.reachable_nodes: + continue condition = edge.get("condition") edge_type = edge.get("type")