From c17a2dad2d3a7273b1d45783c8291a2d338bbb99 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:22:20 +0800 Subject: [PATCH 001/120] style(memory): Some code style optimizations --- .../controllers/memory_storage_controller.py | 204 +++++++++--------- api/app/core/agent/langchain_agent.py | 6 +- api/app/core/memory/models/graph_models.py | 74 ++++--- .../embedding_generation.py | 3 +- api/app/repositories/neo4j/cypher_queries.py | 21 +- api/app/schemas/memory_storage_schema.py | 83 +++---- api/app/services/memory_storage_service.py | 158 +++++++------- api/app/tasks.py | 39 ++-- 8 files changed, 296 insertions(+), 292 deletions(-) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index d91dfc36..d8b39325 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -54,8 +54,8 @@ router = APIRouter( @router.get("/info", response_model=ApiResponse) async def get_storage_info( - storage_id: str, - current_user: User = Depends(get_current_user) + storage_id: str, + current_user: User = Depends(get_current_user) ): """ Example wrapper endpoint - retrieves storage information @@ -75,24 +75,19 @@ async def get_storage_info( return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) - - - - - -@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 +@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 def create_config( - payload: ConfigParamsCreate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), + payload: ConfigParamsCreate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), ) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}") try: # 将 workspace_id 注入到 payload 中(保持为 UUID 类型) @@ -107,9 +102,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {err_str}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) @@ -119,9 +116,11 @@ def create_config( api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") lang = get_language_from_header(x_language_type) if lang == "en": - msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", + f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") else: - msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", + f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) @@ -129,10 +128,10 @@ def create_config( @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) def delete_config( - config_id: UUID|int, - force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """删除记忆配置(带终端用户保护) @@ -145,24 +144,24 @@ def delete_config( force: 设置为 true 可强制删除(即使有终端用户正在使用) """ workspace_id = current_user.current_workspace_id - config_id=resolve_config_id(config_id, db) + config_id = resolve_config_id(config_id, db) # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info( f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: " f"config_id={config_id}, force={force}" ) - + try: # 使用带保护的删除服务 from app.services.memory_config_service import MemoryConfigService - + config_service = MemoryConfigService(db) result = config_service.delete_config(config_id=config_id, force=force) - + if result["status"] == "error": api_logger.warning( f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}" @@ -172,7 +171,7 @@ def delete_config( msg=result["message"], data={"config_id": str(config_id), "is_default": result.get("is_default", False)} ) - + if result["status"] == "warning": api_logger.warning( f"记忆配置正在使用,无法删除: config_id={config_id}, " @@ -186,7 +185,7 @@ def delete_config( "force_required": result["force_required"] } ) - + api_logger.info( f"记忆配置删除成功: config_id={config_id}, " f"affected_users={result['affected_users']}" @@ -195,7 +194,7 @@ def delete_config( msg=result["message"], data={"affected_users": result["affected_users"]} ) - + except Exception as e: api_logger.error(f"Delete config failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) @@ -203,9 +202,9 @@ def delete_config( @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc def update_config( - payload: ConfigUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -213,12 +212,13 @@ def update_config( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + # 校验至少有一个字段需要更新 if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") - return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空") - + return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", + "config_name, config_desc, scene_id 均为空") + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -231,9 +231,9 @@ def update_config( @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 def update_config_extracted( - payload: ConfigUpdateExtracted, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigUpdateExtracted, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id payload.config_id = resolve_config_id(payload.config_id, db) @@ -241,7 +241,7 @@ def update_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}") try: svc = DataConfigService(db) @@ -256,11 +256,11 @@ def update_config_extracted( # 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config -@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 +@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 def read_config_extracted( - config_id: UUID | int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + config_id: UUID | int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id config_id = resolve_config_id(config_id, db) @@ -268,7 +268,7 @@ def read_config_extracted( if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}") try: svc = DataConfigService(db) @@ -278,18 +278,19 @@ def read_config_extracted( api_logger.error(f"Read config extracted failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) -@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 + +@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 def read_all_config( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置") try: svc = DataConfigService(db) @@ -303,14 +304,14 @@ def read_all_config( @router.post("/pilot_run", response_model=None) async def pilot_run( - payload: ConfigPilotRun, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + payload: ConfigPilotRun, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> StreamingResponse: # 使用集中化的语言校验 language = get_language_from_header(language_type) - + api_logger.info( f"Pilot run requested: config_id={payload.config_id}, " f"dialogue_text_length={len(payload.dialogue_text)}, " @@ -333,9 +334,9 @@ async def pilot_run( @router.get("/search/kb_type_distribution", response_model=ApiResponse) async def get_kb_type_distribution( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") try: result = await kb_type_distribution(end_user_id) @@ -344,12 +345,12 @@ async def get_kb_type_distribution( api_logger.error(f"KB type distribution failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e)) - + @router.get("/search/dialogue", response_model=ApiResponse) async def search_dialogues_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") try: result = await search_dialogue(end_user_id) @@ -361,9 +362,9 @@ async def search_dialogues_num( @router.get("/search/chunk", response_model=ApiResponse) async def search_chunks_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") try: result = await search_chunk(end_user_id) @@ -375,9 +376,9 @@ async def search_chunks_num( @router.get("/search/statement", response_model=ApiResponse) async def search_statements_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") try: result = await search_statement(end_user_id) @@ -389,9 +390,9 @@ async def search_statements_num( @router.get("/search/entity", response_model=ApiResponse) async def search_entities_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") try: result = await search_entity(end_user_id) @@ -403,9 +404,9 @@ async def search_entities_num( @router.get("/search", response_model=ApiResponse) async def search_all_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search all requested for end_user_id: {end_user_id}") try: result = await search_all(end_user_id) @@ -417,9 +418,9 @@ async def search_all_num( @router.get("/search/detials", response_model=ApiResponse) async def search_entities_detials( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search details requested for end_user_id: {end_user_id}") try: result = await search_detials(end_user_id) @@ -431,9 +432,9 @@ async def search_entities_detials( @router.get("/search/edges", response_model=ApiResponse) async def search_entity_edges( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: + end_user_id: Optional[str] = None, + current_user: User = Depends(get_current_user), +) -> dict: api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") try: result = await search_edges(end_user_id) @@ -443,14 +444,12 @@ async def search_entity_edges( return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) - - @router.get("/analytics/hot_memory_tags", response_model=ApiResponse) async def get_hot_memory_tags_api( - limit: int = 10, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), - ) -> dict: + limit: int = 10, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +) -> dict: """ 获取热门记忆标签(带Redis缓存) @@ -461,18 +460,18 @@ async def get_hot_memory_tags_api( - 缓存未命中:~600-800ms(取决于LLM速度) """ workspace_id = current_user.current_workspace_id - + # 构建缓存键 cache_key = f"hot_memory_tags:{workspace_id}:{limit}" - + api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}") - + try: # 尝试从Redis缓存获取 import json from app.aioRedis import aio_redis_get, aio_redis_set - + cached_result = await aio_redis_get(cache_key) if cached_result: api_logger.info(f"Cache hit for key: {cache_key}") @@ -481,11 +480,11 @@ async def get_hot_memory_tags_api( return success(data=data, msg="查询成功(缓存)") except json.JSONDecodeError: api_logger.warning(f"Failed to parse cached data, will refresh") - + # 缓存未命中,执行查询 api_logger.info(f"Cache miss for key: {cache_key}, executing query") result = await analytics_hot_memory_tags(db, current_user, limit) - + # 写入缓存(过期时间:5分钟) # 注意:result是列表,需要转换为JSON字符串 try: @@ -495,9 +494,9 @@ async def get_hot_memory_tags_api( except Exception as cache_error: # 缓存写入失败不影响主流程 api_logger.warning(f"Failed to cache result: {str(cache_error)}") - + return success(data=result, msg="查询成功") - + except Exception as e: api_logger.error(f"Hot memory tags failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) @@ -505,8 +504,8 @@ async def get_hot_memory_tags_api( @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) async def clear_hot_memory_tags_cache( - current_user: User = Depends(get_current_user), - ) -> dict: + current_user: User = Depends(get_current_user), +) -> dict: """ 清除热门标签缓存 @@ -516,12 +515,12 @@ async def clear_hot_memory_tags_cache( - 数据更新后立即生效 """ workspace_id = current_user.current_workspace_id - + api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}") - + try: from app.aioRedis import aio_redis_delete - + # 清除所有limit的缓存(常见的limit值) cleared_count = 0 for limit in [5, 10, 15, 20, 30, 50]: @@ -530,12 +529,12 @@ async def clear_hot_memory_tags_cache( if result: cleared_count += 1 api_logger.info(f"Cleared cache for key: {cache_key}") - + return success( - data={"cleared_count": cleared_count}, + data={"cleared_count": cleared_count}, msg=f"成功清除 {cleared_count} 个缓存" ) - + except Exception as e: api_logger.error(f"Clear cache failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e)) @@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache( @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( - current_user: User = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> dict: workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") @@ -553,4 +552,3 @@ async def get_recent_activity_stats_api( except Exception as e: api_logger.error(f"Recent activity stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 88b6371c..464a668a 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -598,8 +598,10 @@ class LangChainAgent: for msg in reversed(output_messages): if isinstance(msg, AIMessage): response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get("total_tokens", - 0) if response_meta else 0 + total_tokens = response_meta.get("token_usage", {}).get( + "total_tokens", + 0 + ) if response_meta else 0 yield total_tokens break if memory_flag: diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 1880b9ab..5a2d8c2e 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -44,21 +44,21 @@ def parse_historical_datetime(v): """ if v is None: return v - + # 处理 Neo4j DateTime 对象 if hasattr(v, 'to_native'): return v.to_native() - + # 处理 Python datetime 对象 if isinstance(v, datetime): return v - + if isinstance(v, str): # 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM] # 支持1-4位年份 pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?' match = re.match(pattern, v) - + if match: try: year = int(match.group(1)) @@ -68,31 +68,31 @@ def parse_historical_datetime(v): minute = int(match.group(5)) if match.group(5) else 0 second = int(match.group(6)) if match.group(6) else 0 microsecond = 0 - + # 处理微秒 if match.group(7): # 补齐或截断到6位 us_str = match.group(7).ljust(6, '0')[:6] microsecond = int(us_str) - + # 处理时区 tzinfo = None if 'Z' in v or match.group(8): tzinfo = timezone.utc - + # 创建 datetime 对象 return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo) - + except (ValueError, OverflowError): # 日期值无效(如月份13、日期32等) return None - + # 如果不匹配模式,尝试使用 fromisoformat(用于标准格式) try: return datetime.fromisoformat(v.replace('Z', '+00:00')) except Exception: return None - + return v @@ -167,7 +167,7 @@ class EntityEntityEdge(Edge): source_statement_id: str = Field(..., description="Statement where this relationship was extracted") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): @@ -206,7 +206,8 @@ class DialogueNode(Node): ref_id: str = Field(..., description="Reference identifier of the dialog") content: str = Field(..., description="Dialogue content") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this dialogue (integer or string)") class StatementNode(Node): @@ -241,17 +242,17 @@ class StatementNode(Node): chunk_id: str = Field(..., description="ID of the parent chunk") stmt_type: str = Field(..., description="Type of the statement") statement: str = Field(..., description="The statement text content") - + # Speaker identification speaker: Optional[str] = Field( None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" ) - + # Emotion fields (ordered as requested, emotion_intensity first for display) emotion_intensity: Optional[float] = Field( - None, - ge=0.0, + None, + ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0 (displayed on node)" ) @@ -264,25 +265,26 @@ class StatementNode(Node): description="Emotion subject: self/other/object" ) emotion_type: Optional[str] = Field( - None, + None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral" ) emotion_keywords: Optional[List[str]] = Field( default_factory=list, description="Emotion keywords list, max 3 items" ) - + # Temporal fields temporal_info: TemporalInfo = Field(..., description="Temporal information") valid_at: Optional[datetime] = Field(None, description="Temporal validity start") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - + # Embedding and other fields statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this statement (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -309,13 +311,13 @@ class StatementNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): """使用通用的历史日期解析函数""" return parse_historical_datetime(v) - + @field_validator('emotion_type', mode='before') @classmethod def validate_emotion_type(cls, v): @@ -326,7 +328,7 @@ class StatementNode(Node): if v not in valid_types: raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") return v - + @field_validator('emotion_subject', mode='before') @classmethod def validate_emotion_subject(cls, v): @@ -337,7 +339,7 @@ class StatementNode(Node): if v not in valid_subjects: raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") return v - + @field_validator('emotion_keywords', mode='before') @classmethod def validate_emotion_keywords(cls, v): @@ -405,19 +407,20 @@ class ExtractedEntityNode(Node): entity_type: str = Field(..., description="Type of the entity") description: str = Field(..., description="Entity description") example: str = Field( - default="", + default="", description="A concise example (around 20 characters) to help understand the entity" ) aliases: List[str] = Field( - default_factory=list, + default_factory=list, description="Entity aliases - alternative names for this entity" ) name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary: str = Field(default="", description="Summary of the fact about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this entity (integer or string)") + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, @@ -444,16 +447,16 @@ class ExtractedEntityNode(Node): ge=0, description="Total number of times this node has been accessed" ) - + # Explicit Memory Classification is_explicit_memory: bool = Field( default=False, description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" ) - + @field_validator('aliases', mode='before') @classmethod - def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 + def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 """Validate and clean aliases field using utility function. This validator ensures that the aliases field is always a valid list of strings. @@ -507,8 +510,9 @@ class MemorySummaryNode(Node): memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") - + config_id: Optional[int | str] = Field(None, + description="Configuration ID used to process this summary (integer or string)") + # ACT-R Forgetting Engine Properties original_statement_id: Optional[str] = Field( None, @@ -522,7 +526,7 @@ class MemorySummaryNode(Node): None, description="Timestamp when the nodes were merged" ) - + # ACT-R Memory Activation Properties importance_score: float = Field( default=0.5, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py index 72f3641e..a0bccc25 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py @@ -227,7 +227,8 @@ class EmbeddingGenerator: # 打印前几个嵌入向量的维度 for i in range(min(5, len(embeddings))): - print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") + print(f"实体 '{entity_texts[i]}' " + f"嵌入向量维度: {len(embeddings[i])}") # 将嵌入向量赋值给实体 for ent, emb in zip(entity_refs, embeddings): diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 5b2e5f1e..0ac7dcb1 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -709,7 +709,6 @@ SET r.end_user_id = e.end_user_id, RETURN elementId(r) AS uuid """ - # Entity Merge Query MERGE_ENTITIES = """ MATCH (canonical:ExtractedEntity {id: $canonical_id}) @@ -829,9 +828,8 @@ neo4j_query_all = """ other as entity2 """ - '''针对当前节点下扩长的句子,实体和总结''' -Memory_Timeline_ExtractedEntity=""" +Memory_Timeline_ExtractedEntity = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) = $id AND (ms:ExtractedEntity OR ms:MemorySummary) @@ -869,7 +867,7 @@ RETURN """ -Memory_Timeline_MemorySummary=""" +Memory_Timeline_MemorySummary = """ MATCH (n)-[r1]-(e)-[r2]-(ms) WHERE elementId(n) =$id AND (ms:MemorySummary OR ms:ExtractedEntity) @@ -904,7 +902,7 @@ RETURN } ) AS statement; """ -Memory_Timeline_Statement=""" +Memory_Timeline_Statement = """ MATCH (n) WHERE elementId(n) = $id @@ -947,7 +945,7 @@ RETURN """ '''针对当前节点,主要获取更加完整的句子节点''' -Memory_Space_Emotion_Statement=""" +Memory_Space_Emotion_Statement = """ MATCH (n) WHERE elementId(n) = $id RETURN @@ -957,7 +955,7 @@ RETURN n.statement AS statement; """ -Memory_Space_Emotion_MemorySummary=""" +Memory_Space_Emotion_MemorySummary = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -970,7 +968,7 @@ RETURN DISTINCT e.emotion_type AS emotion_type, e.statement AS statement; """ -Memory_Space_Emotion_ExtractedEntity=""" +Memory_Space_Emotion_ExtractedEntity = """ MATCH (n)-[]-(e) WHERE elementId(n) = $id AND EXISTS { @@ -985,18 +983,18 @@ RETURN DISTINCT '''获取实体''' -Memory_Space_User=""" +Memory_Space_User = """ MATCH (n)-[r]->(m) WHERE n.end_user_id = $end_user_id AND m.name="用户" return DISTINCT elementId(m) as id """ -Memory_Space_Entity=""" +Memory_Space_Entity = """ MATCH (n)-[]-(m) WHERE elementId(m) = $id AND m.entity_type = "Person" RETURN DISTINCT m.name as name,m.end_user_id as end_user_id """ -Memory_Space_Associative=""" +Memory_Space_Associative = """ MATCH (u)-[]-(x)-[]-(h) WHERE elementId(u) = $user_id AND elementId(h) = $id @@ -1060,7 +1058,6 @@ Graph_Node_query = """ """ - # ============================================================ # Community 节点 & BELONGS_TO_COMMUNITY 边 # ============================================================ diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 046b79e7..c8abbc46 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -8,9 +8,6 @@ import uuid from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator - - - # ============================================================================ # 从 json_schema.py 迁移的 Schema # ============================================================================ @@ -58,10 +55,13 @@ class MemoryVerifySchema(BaseModel): class ConflictResultSchema(BaseModel): """Schema for the conflict result data in the reflexion_data.json file.""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.") + data: List[BaseDataSchema] = Field(..., + description="The conflict memory data. Only contains conflicting records when conflict is True.") conflict: bool = Field(..., description="Whether the memory is in conflict.") - quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") - memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") + quality_assessment: Optional[QualityAssessmentSchema] = Field(None, + description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.") + memory_verify: Optional[MemoryVerifySchema] = Field(None, + description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.") @model_validator(mode="before") def _normalize_data(cls, v): @@ -101,16 +101,19 @@ class ChangeRecordSchema(BaseModel): - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 """ field: List[Dict[str, Any]] = Field( - ..., + ..., description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" ) + class ResolvedSchema(BaseModel): """Schema for the resolved memory data in the reflexion_data""" original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") # resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).") - resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") - change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.") + resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, + description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.") + change: Optional[List[ChangeRecordSchema]] = Field(None, + description="List of detailed change records with IDs and field information.") class SingleReflexionResultSchema(BaseModel): @@ -120,9 +123,11 @@ class SingleReflexionResultSchema(BaseModel): resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.") type: str = Field("reflexion_result", description="The type identifier.") + class ReflexionResultSchema(BaseModel): """Schema for the complete reflexion result data - a list of individual conflict resolutions.""" - results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.") + results: List[SingleReflexionResultSchema] = Field(..., + description="List of individual conflict resolution results, grouped by conflict type.") @model_validator(mode="before") def _normalize_resolved(cls, v): @@ -147,9 +152,9 @@ class ReflexionResultSchema(BaseModel): # Composite key identifying a config row class ConfigKey(BaseModel): # 配置参数键模型 model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") - user_id: str = Field("user_id", description="用户标识(字符串)") - apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") + user_id: str | None = Field(default=None, description="用户标识(字符串)") + apply_id: str | None = Field(default=None, description="应用或场景标识(字符串)") # Allowed chunking strategies (extendable later) @@ -228,23 +233,25 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, config_name: str = Field("配置名称", description="配置名称(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)") workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)") - + # 本体场景关联(可选) scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表") - + # 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入) pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充") - + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") + + class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) model_config = ConfigDict(populate_by_name=True, extra="forbid") # config_name: str = Field("配置名称", description="配置名称(字符串)") - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 @@ -255,7 +262,7 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") @@ -322,14 +329,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 # 遗忘引擎配置参数更新模型 - config_id:Union[uuid.UUID, int, str] = None + config_id: Union[uuid.UUID, int, str] = None lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0") class ConfigPilotRun(BaseModel): # 试运行触发请求模型 - config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行") model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -364,11 +371,11 @@ def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) def fail( - msg: str, - error_code: str = "ERROR", - data: Optional[Any] = None, - time: Optional[int] = None, - query_preview: Optional[str] = None, + msg: str, + error_code: str = "ERROR", + data: Optional[Any] = None, + time: Optional[int] = None, + query_preview: Optional[str] = None, ) -> ApiResponse: payload = data if query_preview is not None: @@ -387,12 +394,13 @@ def fail( time=time or _now_ms(), ) + class GenerateCacheRequest(BaseModel): """缓存生成请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: Optional[str] = Field( - None, + None, description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成" ) @@ -404,7 +412,7 @@ class GenerateCacheRequest(BaseModel): class ForgettingTriggerRequest(BaseModel): """手动触发遗忘周期请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + end_user_id: str = Field(..., description="组ID(即终端用户ID,必填)") max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)") min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)") @@ -413,7 +421,7 @@ class ForgettingTriggerRequest(BaseModel): class ForgettingConfigResponse(BaseModel): """遗忘引擎配置响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)") decay_constant: float = Field(..., description="衰减常数 d") lambda_time: float = Field(..., description="时间衰减参数") @@ -432,7 +440,7 @@ class ForgettingConfigUpdateRequest(BaseModel): """遗忘引擎配置更新请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)") + config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") @@ -448,7 +456,7 @@ class ForgettingConfigUpdateRequest(BaseModel): class ForgettingCycleHistoryPoint(BaseModel): """遗忘周期历史数据点模型(用于趋势图)""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + date: str = Field(..., description="日期(格式: '1/1', '1/2')") merged_count: int = Field(..., description="每日融合节点数") average_activation: Optional[float] = Field(None, description="平均激活值") @@ -459,7 +467,7 @@ class ForgettingCycleHistoryPoint(BaseModel): class PendingForgettingNode(BaseModel): """待遗忘节点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + node_id: str = Field(..., description="节点ID") node_type: str = Field(..., description="节点类型:statement/entity/summary") content_summary: str = Field(..., description="内容摘要") @@ -472,7 +480,8 @@ class ForgettingStatsResponse(BaseModel): model_config = ConfigDict(populate_by_name=True, extra="forbid") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") node_distribution: Dict[str, int] = Field(..., description="节点类型分布") - recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") + recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., + description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") timestamp: int = Field(..., description="统计时间(时间戳)") @@ -480,7 +489,7 @@ class ForgettingStatsResponse(BaseModel): class ForgettingReportResponse(BaseModel): """遗忘周期报告响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + merged_count: int = Field(..., description="融合的节点对数量") nodes_before: int = Field(..., description="遗忘前的节点总数") nodes_after: int = Field(..., description="遗忘后的节点总数") @@ -495,7 +504,7 @@ class ForgettingReportResponse(BaseModel): class ForgettingCurvePoint(BaseModel): """遗忘曲线数据点模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + day: int = Field(..., description="天数") activation: float = Field(..., description="激活值") retention_rate: float = Field(..., description="保持率(与激活值相同)") @@ -504,7 +513,7 @@ class ForgettingCurvePoint(BaseModel): class ForgettingCurveRequest(BaseModel): """遗忘曲线请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)") days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)") config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)") @@ -513,6 +522,6 @@ class ForgettingCurveRequest(BaseModel): class ForgettingCurveResponse(BaseModel): """遗忘曲线响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - + curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表") config: Dict[str, Any] = Field(..., description="使用的配置参数") diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 6e7c1ad4..264ae4df 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -11,9 +11,11 @@ import time from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional +from dotenv import load_dotenv +from sqlalchemy.orm import Session + from app.core.logging_config import get_config_logger, get_logger from app.core.memory.analytics.hot_memory_tags import ( - get_hot_memory_tags, get_raw_tags_from_db, filter_tags_with_llm, ) @@ -32,8 +34,6 @@ from app.schemas.memory_storage_schema import ( ) from app.services.memory_config_service import MemoryConfigService from app.utils.sse_utils import format_sse_message -from dotenv import load_dotenv -from sqlalchemy.orm import Session logger = get_logger(__name__) config_logger = get_config_logger() @@ -45,10 +45,10 @@ _neo4j_connector = Neo4jConnector() class MemoryStorageService: """Service for memory storage operations""" - + def __init__(self): logger.info("MemoryStorageService initialized") - + async def get_storage_info(self) -> dict: """ Example wrapper method - retrieves storage information @@ -59,17 +59,17 @@ class MemoryStorageService: Storage information dictionary """ logger.info("Getting storage info ") - + # Empty wrapper - implement your logic here result = { "status": "active", "message": "This is an example wrapper" } - - return result - -class DataConfigService: # 数据配置服务类(PostgreSQL) + return result + + +class DataConfigService: # 数据配置服务类(PostgreSQL) """Service layer for config params CRUD. 使用 SQLAlchemy ORM 进行数据库操作。 @@ -114,7 +114,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return data_list # --- Create --- - def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) + def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) # 业务层检查同一工作空间下是否已存在同名配置 if params.workspace_id and params.config_name: from app.models.memory_config_model import MemoryConfig @@ -183,20 +183,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return None # --- Delete --- - def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) + def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) success = MemoryConfigRepository.delete(self.db, key.config_id) if not success: raise ValueError("未找到配置") return {"affected": 1} # --- Update --- - def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 + def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 config = MemoryConfigRepository.update(self.db, update) if not config: raise ValueError("未找到配置") return {"affected": 1} - def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 + def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 config = MemoryConfigRepository.update_extracted(self.db, update) if not config: raise ValueError("未找到配置") @@ -207,14 +207,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config() # --- Read --- - def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 + def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id) if not result: raise ValueError("未找到配置") return result # --- Read All --- - def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 + def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数 results = MemoryConfigRepository.get_all(self.db, workspace_id) # 检查并修正 pruning_scene 与 scene_name 不一致的记录 @@ -241,11 +241,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) except (ValueError, TypeError): config_id_old = None - if config_id_old: - memory_config=config_id_old + memory_config = config_id_old else: - memory_config=config.config_id + memory_config = config.config_id config_dict = { "config_id": memory_config, "config_name": config.config_name, @@ -289,7 +288,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 return self._convert_timestamps_to_format(data_list) - async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]: """ 流式执行试运行,产生 SSE 格式的进度事件 @@ -311,14 +309,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) """ from pathlib import Path project_root = str(Path(__file__).resolve().parents[2]) - + try: # 发出初始进度事件 yield format_sse_message("starting", { "message": "开始试运行...", "time": int(time.time() * 1000) }) - + # 步骤 1: 配置加载和验证(数据库优先) payload_cid = str(getattr(payload, "config_id", "") or "").strip() cid: Optional[str] = payload_cid if payload_cid else None @@ -344,27 +342,28 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # 关联了本体场景,优先使用 custom_text if hasattr(payload, 'custom_text') and payload.custom_text: dialogue_text = payload.custom_text.strip() - logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") + logger.info( + f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}") else: # 如果没有提供 custom_text,回退到 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" - logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") + logger.info( + f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}") else: # 没有关联本体场景,使用 dialogue_text dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}") - + # 验证最终使用的文本不为空 if not dialogue_text: raise ValueError("试运行模式必须提供有效的文本内容(dialogue_text 或 custom_text)") - - logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") + logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}") # 步骤 2: 创建进度回调函数捕获管线进度 # 使用队列在回调和生成器之间传递进度事件 progress_queue: asyncio.Queue = asyncio.Queue() - + async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None: """ 进度回调函数,将进度事件放入队列 @@ -375,14 +374,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) data: 可选的结果数据(用于传递节点执行结果) """ await progress_queue.put((stage, message, data)) - + # 步骤 3: 在后台任务中执行管线 async def run_pipeline(): """在后台执行管线并捕获异常""" try: from app.services.pilot_run_service import run_pilot_extraction - - logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") + + logger.info( + f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}") await run_pilot_extraction( memory_config=memory_config, dialogue_text=dialogue_text, @@ -391,60 +391,60 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) language=language, ) logger.info("[PILOT_RUN_STREAM] pipeline_main completed") - + # 标记管线完成 await progress_queue.put(("__PIPELINE_COMPLETE__", "", None)) except Exception as e: # 将异常放入队列 await progress_queue.put(("__PIPELINE_ERROR__", str(e), None)) - + # 启动后台任务 pipeline_task = asyncio.create_task(run_pipeline()) - + # 步骤 4: 从队列中读取进度事件并发出 while True: try: # 等待进度事件,设置超时以检测客户端断开 stage, message, data = await asyncio.wait_for( - progress_queue.get(), + progress_queue.get(), timeout=0.5 ) - + # 检查特殊标记 if stage == "__PIPELINE_COMPLETE__": break elif stage == "__PIPELINE_ERROR__": raise RuntimeError(message) - + # 构建进度事件数据 progress_data = { "message": message, "time": int(time.time() * 1000) } - + # 如果有结果数据,添加到事件中 if data: progress_data["data"] = data - + # 发出进度事件,使用 stage 作为事件类型 yield format_sse_message(stage, progress_data) - + except TimeoutError: # 超时,继续等待(这允许检测客户端断开) continue - + # 等待管线任务完成 await pipeline_task - + # 步骤 5: 读取提取结果 from app.core.config import settings result_path = settings.get_memory_output_path("extracted_result.json") if not os.path.isfile(result_path): raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") - + with open(result_path, "r", encoding="utf-8") as rf: extracted_result = json.load(rf) - + # 步骤 6: 计算本体覆盖率并合并到结果中 result_data = { "config_id": cid, @@ -460,15 +460,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) result_data["ontology_coverage"] = ontology_coverage except Exception as cov_err: logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True) - + yield format_sse_message("result", result_data) - + # 步骤 7: 发出完成事件 yield format_sse_message("done", { "message": "试运行完成", "time": int(time.time() * 1000) }) - + except asyncio.CancelledError: # 客户端断开连接 logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming") @@ -483,11 +483,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "time": int(time.time() * 1000) }) - async def _compute_ontology_coverage( - self, - extracted_result: Dict[str, Any], - memory_config, + self, + extracted_result: Dict[str, Any], + memory_config, ) -> Optional[Dict[str, Any]]: """根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。 @@ -580,8 +579,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # Ensure env for connector (e.g., NEO4J_PASSWORD) -load_dotenv() -_neo4j_connector = Neo4jConnector() async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: @@ -664,7 +661,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A # 检查结果是否为空或长度不足 if not result or len(result) < 4: data = { - "total": 0, + "total": 0, "distribution": [ {"type": "dialogue", "count": 0}, {"type": "chunk", "count": 0}, @@ -701,10 +698,11 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] ) return result + async def analytics_hot_memory_tags( - db: Session, - current_user: User, - limit: int = 10 + db: Session, + current_user: User, + limit: int = 10 ) -> List[Dict[str, Any]]: """ 获取热门记忆标签,按数量排序并返回前N个 @@ -721,27 +719,27 @@ async def analytics_hot_memory_tags( from app.services.memory_dashboard_service import get_workspace_end_users # 使用 asyncio.to_thread 避免阻塞事件循环 end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user) - + if not end_users: return [] - + # 步骤1: 收集所有用户的原始标签(不调用LLM) connector = Neo4jConnector() try: all_raw_tags = [] for end_user in end_users: raw_tags = await get_raw_tags_from_db( - connector, - str(end_user.id), - limit=raw_limit, + connector, + str(end_user.id), + limit=raw_limit, by_user=False ) if raw_tags: all_raw_tags.extend(raw_tags) - + if not all_raw_tags: return [] - + # 步骤2: 聚合相同标签的频率 tag_frequency_map = {} for tag_name, frequency in all_raw_tags: @@ -749,36 +747,36 @@ async def analytics_hot_memory_tags( tag_frequency_map[tag_name] += frequency else: tag_frequency_map[tag_name] = frequency - + # 步骤3: 按频率降序排序,取前raw_limit个 sorted_tags = sorted( - tag_frequency_map.items(), - key=lambda x: x[1], + tag_frequency_map.items(), + key=lambda x: x[1], reverse=True )[:raw_limit] - + if not sorted_tags: return [] - + # 步骤4: 只调用一次LLM进行筛选 tag_names = [tag for tag, _ in sorted_tags] - + # 使用第一个用户的end_user_id来获取LLM配置 # 因为同一工作空间下的用户应该使用相同的配置 first_end_user_id = str(end_users[0].id) filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id) - + # 步骤5: 根据LLM筛选结果构建最终列表(保留频率) final_tags = [] for tag, freq in sorted_tags: if tag in filtered_tag_names: final_tags.append((tag, freq)) - + # 步骤6: 只返回前limit个 top_tags = final_tags[:limit] - + return [{"name": t, "frequency": f} for t, f in top_tags] - + finally: await connector.close() @@ -815,11 +813,11 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> source = "log" total = ( - stats.get("chunk_count", 0) - + stats.get("statements_count", 0) - + stats.get("triplet_entities_count", 0) - + stats.get("triplet_relations_count", 0) - + stats.get("temporal_count", 0) + stats.get("chunk_count", 0) + + stats.get("statements_count", 0) + + stats.get("triplet_entities_count", 0) + + stats.get("triplet_relations_count", 0) + + stats.get("temporal_count", 0) ) # 计算"最新一次活动多久前"(仅日志来源时有效) @@ -845,5 +843,3 @@ async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} return data - - diff --git a/api/app/tasks.py b/api/app/tasks.py index f5258330..c37e564e 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1073,9 +1073,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, - user_rag_memory_id: str, - language: str = "zh") -> Dict[str, Any]: +def write_message_task( + self, + end_user_id: str, + message: list[dict], + config_id: str | int, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" +) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) @@ -1105,14 +1111,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s try: with get_db_context() as db: actual_config_id = resolve_config_id(config_id, db) - print(100 * '-') - print(actual_config_id) - print(100 * '-') - logger.info( - f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") + logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} " + f"(type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: - logger.error( - f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") + logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} " + f"(type: {type(config_id).__name__}), error: {e}") return { "status": "FAILURE", "error": f"Invalid config_id format: {config_id} - {str(e)}", @@ -1151,8 +1154,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time - logger.info( - f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + logger.info(f"[CELERY WRITE] Task completed successfully " + f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: @@ -1167,7 +1170,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s ) except Exception as _e: logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") - return { "status": "SUCCESS", "result": result, @@ -2672,7 +2674,7 @@ def write_perceptual_memory( ignore_result=False, max_retries=0, acks_late=False, - time_limit=7200, # 2小时硬超时 + time_limit=7200, # 2小时硬超时 soft_time_limit=6900, ) def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: @@ -2749,7 +2751,8 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s llm_model_id=llm_model_id, ) - logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") + logger.info( + f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") await engine.full_clustering(end_user_id) initialized += 1 logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") @@ -2772,12 +2775,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time From dce7206c447893bf89d7d3efcf4046c3139b22d6 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:28:21 +0800 Subject: [PATCH 002/120] fix(celery, rag): unify rag_write return type and remove deprecated downstream calls - Unify the return type of `rag_write` in Celery tasks for consistency. - Remove two deprecated downstream API calls to avoid obsolete dependencies. --- .../controllers/memory_agent_controller.py | 272 +++++++++--------- .../repositories/memory_config_repository.py | 48 +--- api/app/services/memory_agent_service.py | 20 +- api/app/services/memory_konwledges_server.py | 19 +- 4 files changed, 169 insertions(+), 190 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index e3d2bf92..aa4d48e3 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -118,142 +118,142 @@ async def download_log( return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) -@router.post("/writer_service", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Write service endpoint - processes write operations synchronously - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Response with write operation status - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - - # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id - if storage_type == 'rag': - if workspace_id: - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: - user_rag_memory_id = str(knowledge.id) - else: - api_logger.warning( - f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") - storage_type = 'neo4j' - else: - api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") - storage_type = 'neo4j' - - api_logger.info( - f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") - try: - messages_list = memory_agent_service.get_messages_list(user_input) - result = await memory_agent_service.write_memory( - user_input.end_user_id, - messages_list, - config_id, - db, - storage_type, - user_rag_memory_id, - language - ) - - return success(data=result, msg="写入成功") - except BaseException as e: - # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup - if hasattr(e, 'exceptions'): - error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] - detailed_error = "; ".join(error_messages) - api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) - api_logger.error(f"Write operation error: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) - - -@router.post("/writer_service_async", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server_async( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Async write service endpoint - enqueues write processing to Celery - - Args: - user_input: Write request containing message and end_user_id - language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 - - Returns: - Task ID for tracking async operation - Use GET /memory/write_result/{task_id} to check task status and get result - """ - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - - config_id = user_input.config_id - workspace_id = current_user.current_workspace_id - api_logger.info( - f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - try: - # 获取标准化的消息列表 - messages_list = memory_agent_service.get_messages_list(user_input) - - task = celery_app.send_task( - "app.core.memory.agent.write_message", - args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] - ) - api_logger.info(f"Write task queued: {task.id}") - - return success(data={"task_id": task.id}, msg="写入任务已提交") - except Exception as e: - api_logger.error(f"Async write operation failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# @router.post("/writer_service", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Write service endpoint - processes write operations synchronously +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Response with write operation status +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# +# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id +# if storage_type == 'rag': +# if workspace_id: +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: +# user_rag_memory_id = str(knowledge.id) +# else: +# api_logger.warning( +# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") +# storage_type = 'neo4j' +# else: +# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") +# storage_type = 'neo4j' +# +# api_logger.info( +# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") +# try: +# messages_list = memory_agent_service.get_messages_list(user_input) +# result = await memory_agent_service.write_memory( +# user_input.end_user_id, +# messages_list, +# config_id, +# db, +# storage_type, +# user_rag_memory_id, +# language +# ) +# +# return success(data=result, msg="写入成功") +# except BaseException as e: +# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup +# if hasattr(e, 'exceptions'): +# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] +# detailed_error = "; ".join(error_messages) +# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) +# api_logger.error(f"Write operation error: {str(e)}", exc_info=True) +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) +# +# +# @router.post("/writer_service_async", response_model=ApiResponse) +# @cur_workspace_access_guard() +# async def write_server_async( +# user_input: Write_UserInput, +# language_type: str = Header(default=None, alias="X-Language-Type"), +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user) +# ): +# """ +# Async write service endpoint - enqueues write processing to Celery +# +# Args: +# user_input: Write request containing message and end_user_id +# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 +# +# Returns: +# Task ID for tracking async operation +# Use GET /memory/write_result/{task_id} to check task status and get result +# """ +# # 使用集中化的语言校验 +# language = get_language_from_header(language_type) +# +# config_id = user_input.config_id +# workspace_id = current_user.current_workspace_id +# api_logger.info( +# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") +# +# # 获取 storage_type,如果为 None 则使用默认值 +# storage_type = workspace_service.get_workspace_storage_type( +# db=db, +# workspace_id=workspace_id, +# user=current_user +# ) +# if storage_type is None: storage_type = 'neo4j' +# user_rag_memory_id = '' +# if workspace_id: +# +# knowledge = knowledge_repository.get_knowledge_by_name( +# db=db, +# name="USER_RAG_MERORY", +# workspace_id=workspace_id +# ) +# if knowledge: user_rag_memory_id = str(knowledge.id) +# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") +# try: +# # 获取标准化的消息列表 +# messages_list = memory_agent_service.get_messages_list(user_input) +# +# task = celery_app.send_task( +# "app.core.memory.agent.write_message", +# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] +# ) +# api_logger.info(f"Write task queued: {task.id}") +# +# return success(data={"task_id": task.id}, msg="写入任务已提交") +# except Exception as e: +# api_logger.error(f"Async write operation failed: {str(e)}") +# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) @router.post("/read_service", response_model=ApiResponse) diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 22f13449..5c2f81a7 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -309,57 +309,21 @@ class MemoryConfigRepository: Returns: Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None - - Raises: - ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"更新萃取配置: config_id={update.config_id}") try: - db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() + stmt = select(MemoryConfig).where(MemoryConfig.config_id == update.config_id) + db_config = db.execute(stmt).scalar_one_or_none() if not db_config: db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None - # 更新字段映射 - field_mapping = { - # 模型选择 - "llm_id": "llm_id", - "embedding_id": "embedding_id", - "rerank_id": "rerank_id", - # 记忆萃取引擎 - "enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise", - "enable_llm_disambiguation": "enable_llm_disambiguation", - "deep_retrieval": "deep_retrieval", - "t_type_strict": "t_type_strict", - "t_name_strict": "t_name_strict", - "t_overall": "t_overall", - "state": "state", - "chunker_strategy": "chunker_strategy", - # 句子提取 - "statement_granularity": "statement_granularity", - "include_dialogue_context": "include_dialogue_context", - "max_context": "max_context", - # 剪枝配置 - "pruning_enabled": "pruning_enabled", - "pruning_scene": "pruning_scene", - "pruning_threshold": "pruning_threshold", - # 自我反思配置 - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - } + update_data = update.model_dump(exclude_unset=True) + update_data.pop("config_id", None) - has_update = False - for api_field, db_field in field_mapping.items(): - value = getattr(update, api_field, None) - if value is not None: - setattr(db_config, db_field, value) - has_update = True - - if not has_update: - raise ValueError("No fields to update") + for field, value in update_data.items(): + setattr(db_config, field, value) db.commit() db.refresh(db_config) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index af9a04e2..514cb12f 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -267,8 +267,16 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, - db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: + async def write_memory( + self, + end_user_id: str, + messages: list[dict], + config_id: Optional[uuid.UUID] | int, + db: Session, + storage_type: str, + user_rag_memory_id: str, + language: str = "zh" + ) -> str: """ Process write operation with config_id @@ -297,8 +305,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError( - f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. " + f"Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -338,8 +346,8 @@ class MemoryAgentService: if storage_type == "rag": # For RAG storage, convert messages to single string message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(end_user_id, message_text, user_rag_memory_id) - return result + await write_rag(end_user_id, message_text, user_rag_memory_id) + return "success" else: async with make_write_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index b8961d33..523adadb 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -341,7 +341,7 @@ async def memory_konwledges_up( ) db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) - return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") + return db_document async def create_document_chunk( @@ -350,7 +350,7 @@ async def create_document_chunk( create_data: ChunkCreate, db: Session, current_user: User -): +) -> DocumentChunk: """ 创建文档块 @@ -439,10 +439,10 @@ async def create_document_chunk( db_document.chunk_num += 1 db.commit() - return success(data=chunk, msg="文档块创建成功") + return chunk -async def write_rag(end_user_id, message, user_rag_memory_id): +async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk: """ 将消息写入 RAG 知识库 @@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id): document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") print('======', document) api_logger.info(f"查找文档结果: document_id={document}") + create_chunks = ChunkCreate(content=message) if document is not None: # 文档已存在,直接添加新块 api_logger.info(f"文档已存在,添加新块: document_id={document}") - create_chunks = ChunkCreate(content=message) result = await create_document_chunk( kb_id=kb_uuid, document_id=uuid.UUID(document), @@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id): else: # 文档不存在,创建新文档 api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}") - result = await memory_konwledges_up( + document = await memory_konwledges_up( kb_id=user_rag_memory_id, parent_id=user_rag_memory_id, create_data=create_data, db=db, current_user=current_user ) + result = await create_document_chunk( + kb_id=kb_uuid, + document_id=document.id, + create_data=create_chunks, + db=db, + current_user=current_user + ) # 重新查询刚创建的文档ID new_document_id = find_document_id_by_kb_and_filename( db=db, From 31085ed678ef29d77ba9a7feaa59338a7201d195 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:31:17 +0800 Subject: [PATCH 003/120] fix(workflow): fix memory write behavior in RAG workspace --- .../core/workflow/engine/runtime_schema.py | 14 ++++- api/app/core/workflow/engine/state_manager.py | 9 ++- api/app/core/workflow/engine/variable_pool.py | 10 ++- api/app/core/workflow/executor.py | 20 ++++-- api/app/core/workflow/nodes/memory/node.py | 61 ++++++++++++++++--- api/app/services/workflow_service.py | 29 ++++++++- api/app/services/workspace_service.py | 2 +- 7 files changed, 128 insertions(+), 17 deletions(-) diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py index e4bf65af..48eafaa9 100644 --- a/api/app/core/workflow/engine/runtime_schema.py +++ b/api/app/core/workflow/engine/runtime_schema.py @@ -12,14 +12,26 @@ class ExecutionContext(BaseModel): execution_id: str workspace_id: str user_id: str + memory_storage_type: str + user_rag_memory_id: str checkpoint_config: RunnableConfig @classmethod - def create(cls, execution_id: str, workspace_id: str, user_id: str): + def create( + cls, + execution_id: str, + workspace_id: str, + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str + ): return cls( execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id, + checkpoint_config=RunnableConfig( configurable={ "thread_id": uuid.uuid4(), diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 0a4a1463..2da0d3a8 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -33,6 +33,8 @@ class WorkflowState(dict): "workspace_id", "user_id", "activate", + "memory_storage_type", + "user_rag_memory_id" }) __optional_keys__ = frozenset({ "error", @@ -62,6 +64,9 @@ class WorkflowState(dict): # node activate status activate: Annotated[dict[str, bool], merge_activate_state] + memory_storage_type: str + user_rag_memory_id: str + class WorkflowStateManager: def create_initial_state( @@ -85,7 +90,9 @@ class WorkflowStateManager: looping=0, activate={ start_node_id: True - } + }, + memory_storage_type=execution_context.memory_storage_type, + user_rag_memory_id=execution_context.user_rag_memory_id ) @staticmethod diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index cf6f4a7b..d4e1b488 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE -from app.core.workflow.variable.variable_objects import T, create_variable_instance +from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable logger = logging.getLogger(__name__) @@ -373,6 +373,14 @@ class VariablePool: def copy(self, pool: 'VariablePool'): self.variables = deepcopy(pool.variables) + def is_file_variable(self, selector): + variable_struct = self._get_variable_struct(selector) + if isinstance(variable_struct, FileVariable): + return True + elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: + return True + return False + def to_dict(self) -> dict[str, Any]: """导出为字典 diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index c9ed6e65..6a127e96 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -409,7 +409,9 @@ async def execute_workflow( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ) -> dict[str, Any]: """ Execute a workflow (convenience function, non-streaming). @@ -420,6 +422,8 @@ async def execute_workflow( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id: rag knowledge db id + memory_storage_type: neo4j / rag Returns: dict: Workflow execution result. @@ -427,7 +431,9 @@ async def execute_workflow( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, @@ -441,7 +447,9 @@ async def execute_workflow_stream( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ): """ Execute a workflow in streaming mode (convenience function). @@ -452,6 +460,8 @@ async def execute_workflow_stream( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id: rag knowledge db id + memory_storage_type: neo4j / rag Yields: dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. @@ -459,7 +469,9 @@ async def execute_workflow_stream( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 1d42e82e..82363056 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -1,3 +1,4 @@ +import re from typing import Any from app.core.workflow.engine.state_manager import WorkflowState @@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.db import get_db_read +from app.schemas import FileInput from app.services.memory_agent_service import MemoryAgentService from app.tasks import write_message_task @@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode): search_switch=self.typed_config.search_switch, history=[], db=db, - storage_type="neo4j", - user_rag_memory_id="" + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) @@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} + @staticmethod + def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]: + variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}' + variable_pattern = re.compile(variable_pattern_string) + variables = variable_pattern.findall(content) + file_variables = [] + for variable in variables: + if variable_pool.is_file_variable(variable): + file_variables.append(variable) + for var in file_variables: + content = content.replace(var, "") + return file_variables, content + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryWriteNodeConfig(**self.config) end_user_id = self.get_variable("sys.user_id", variable_pool) @@ -56,6 +72,7 @@ class MemoryWriteNode(BaseNode): if not end_user_id: raise RuntimeError("End user id is required") messages = [] + multimodal_memories = [] if self.typed_config.message: messages.append({ "role": "user", @@ -63,17 +80,45 @@ class MemoryWriteNode(BaseNode): }) for message in self.typed_config.messages: + file_variables, content = self._extract_multimodal_memory_variables( + message.content, + variable_pool + ) + file_info = [] + for var in file_variables: + instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var) + if isinstance(instence, FileVariable): + file_info.append(FileInput( + type=instence.value.type, + transfer_method=instence.value.transfer_method, + upload_file_id=instence.value.file_id, + url=instence.value.url, + file_type=instence.value.origin_file_type + ).model_dump()) + elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable: + for file_instence in instence.value: + file_info.append(FileInput( + type=file_instence.value.type, + transfer_method=file_instence.value.transfer_method, + upload_file_id=file_instence.value.file_id, + url=file_instence.value.url, + file_type=file_instence.value.origin_file_type + ).model_dump()) + multimodal_memories.append({ + "role": message.role, + "files": file_info + }) messages.append({ "role": message.role, - "content": self._render_template(message.content, variable_pool) + "content": self._render_template(content, variable_pool) }) write_message_task.delay( - end_user_id, - messages, - str(self.typed_config.config_id), - "neo4j", - "" + end_user_id=end_user_id, + message=messages, + config_id=str(self.typed_config.config_id), + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) return "success" diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 56f34496..db659268 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories import knowledge_repository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, @@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService +from app.services.workspace_service import get_workspace_storage_type_without_auth logger = logging.getLogger(__name__) @@ -536,6 +538,25 @@ class WorkflowService: mapped = internal_event return mapped + def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]: + storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id) + user_rag_memory_id = "" + if storage_type == "rag": + knowledge = knowledge_repository.get_knowledge_by_name( + db=self.db, + name="USER_RAG_MERORY", + workspace_id=workspace_id + ) + if knowledge: + user_rag_memory_id = str(knowledge.id) + else: + logger.warning( + f"No knowledge base named 'USER_RAG_MEMORY' found, " + f"workspace_id: {workspace_id}, will use neo4j storage" + ) + storage_type = 'neo4j' + return storage_type, user_rag_memory_id + # ==================== 工作流执行 ==================== async def run( @@ -603,6 +624,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files message_id = uuid.uuid4() # 更新状态为运行中 @@ -627,7 +649,9 @@ class WorkflowService: input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ) # 更新执行结果 if result.get("status") == "completed": @@ -776,6 +800,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) @@ -797,6 +822,8 @@ class WorkflowService: execution_id=execution.execution_id, workspace_id=str(workspace_id), user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): if event.get("event") == "workflow_end": status = event.get("data", {}).get("status") diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index cefb8380..90b5cf65 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -863,7 +863,7 @@ def get_workspace_storage_type( def get_workspace_storage_type_without_auth( db: Session, workspace_id: uuid.UUID, -) -> Optional[str]: +) -> str: """获取工作空间的存储类型(无需权限验证,用于公开分享等场景) Args: From 2ff81ba101b422e600982dcc96fc673bc4684223 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 23 Mar 2026 16:33:25 +0800 Subject: [PATCH 004/120] feat(memory): support perception-aware memory writing in workflow and Neo4j nodes --- .../core/memory/agent/utils/write_tools.py | 65 ++- api/app/core/memory/models/graph_models.py | 18 + .../deduplication/two_stage_dedup.py | 24 +- .../extraction_orchestrator.py | 465 +++++++++--------- .../knowledge_extraction/memory_summary.py | 1 - api/app/core/workflow/engine/variable_pool.py | 4 +- api/app/core/workflow/nodes/base_node.py | 8 +- api/app/core/workflow/nodes/llm/node.py | 14 +- api/app/core/workflow/nodes/memory/node.py | 1 + api/app/models/memory_config_model.py | 3 + .../repositories/memory_config_repository.py | 57 +-- api/app/repositories/neo4j/add_nodes.py | 111 ++++- api/app/repositories/neo4j/cypher_queries.py | 33 ++ api/app/schemas/memory_config_schema.py | 6 + api/app/services/app_chat_service.py | 4 +- api/app/services/draft_run_service.py | 4 +- api/app/services/memory_agent_service.py | 85 ++-- api/app/services/memory_api_service.py | 85 ++-- api/app/services/memory_config_service.py | 194 +++++--- api/app/services/memory_perceptual_service.py | 120 ++++- api/app/services/multimodal_service.py | 31 +- api/app/tasks.py | 6 +- 22 files changed, 820 insertions(+), 519 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b62eb50a..147a0316 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract pipeline. Only MemoryConfig is needed - clients are constructed internally. """ import asyncio +import uuid import time from datetime import datetime @@ -13,28 +14,31 @@ from dotenv import load_dotenv from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \ + memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context +from app.models import MemoryPerceptualModel from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges -from app.repositories.neo4j.add_nodes import add_memory_summary_nodes +from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \ + add_perceptual_dialogue_edges from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig - load_dotenv() logger = get_agent_logger(__name__) async def write( - end_user_id: str, - memory_config: MemoryConfig, - messages: list, - ref_id: str = "wyl20251027", - language: str = "zh", + end_user_id: str, + memory_config: MemoryConfig, + messages: list, + file_content: list[MemoryPerceptualModel], + ref_id: str = "", + language: str = "zh", ) -> None: """ Execute the complete knowledge extraction pipeline. @@ -43,9 +47,12 @@ async def write( end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] - ref_id: Reference ID, defaults to "wyl20251027" + file_content: mutilmodal message list + ref_id: Reference ID, defaults to "" language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ + if not ref_id: + ref_id = uuid.uuid4().hex # Extract config values embedding_model_id = str(memory_config.embedding_model_id) chunker_strategy = memory_config.chunker_strategy @@ -99,14 +106,14 @@ async def write( if memory_config.scene_id: try: from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene - + with get_db_context() as db: ontology_types = load_ontology_types_for_scene( scene_id=memory_config.scene_id, workspace_id=memory_config.workspace_id, db=db ) - + if ontology_types: logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" @@ -173,7 +180,8 @@ async def write( schedule_clustering_after_write( all_entity_nodes, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + embedding_model_id=str( + memory_config.embedding_model_id) if memory_config.embedding_model_id else None, ) break else: @@ -208,9 +216,8 @@ async def write( summaries = await memory_summary_generation( chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language ) - + ms_connector = Neo4jConnector() try: - ms_connector = Neo4jConnector() await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector) finally: @@ -223,6 +230,34 @@ async def write( finally: log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file) + # Step 5: Save perceptual memory to Neo4j + step_start = time.time() + if file_content: + try: + pc_connector = Neo4jConnector() + try: + created_ids = await add_perceptual_nodes( + perceptuals=file_content, + connector=pc_connector, + embedder_client=embedder_client, + ) + # 如果有 ref_id,建立感知记忆与对话的关联 + if ref_id and created_ids: + await add_perceptual_dialogue_edges( + perceptuals=file_content, + dialog_id=ref_id, + connector=pc_connector, + ) + logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j") + finally: + try: + await pc_connector.close() + except Exception: + pass + except Exception as e: + logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True) + log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file) + # Log total pipeline time total_time = time.time() - pipeline_start log_time("TOTAL PIPELINE TIME", total_time, log_file) @@ -251,4 +286,4 @@ async def write( logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file + logger.info(f"Total execution time: {total_time:.2f} seconds") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 5a2d8c2e..fb251f1f 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -553,3 +553,21 @@ class MemorySummaryNode(Node): ge=0, description="Total number of times this node has been accessed (reset to 1 on creation)" ) + + +class MutlimodalNode(Node): + """Node representing a multimodal message in the knowledge graph. + + Attributes: + dialog_id: ID of the parent dialog + message_id: ID of the message + metadata: Additional message metadata + embedding: Optional embedding vector for the message + """ + dialog_id: str = Field(..., description="ID of the parent dialog") + message_id: str = Field(..., description="ID of the message") + summary: str = Field(..., description="The text content of the message") + file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')") + file_path: List[str] = Field(..., description="List of file paths for multimodal content") + metadata: dict = Field(default_factory=dict, description="Additional message metadata") + embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message") diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index f28b8a5f..4b9c5718 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector async def dedup_layers_and_merge_and_return( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - pipeline_config: ExtractionPipelineConfig, - connector: Optional[Neo4jConnector] = None, - llm_client = None, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + pipeline_config: ExtractionPipelineConfig, + connector: Optional[Neo4jConnector] = None, + llm_client=None, ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return( List[StatementChunkEdge], List[StatementEntityEdge], List[EntityEntityEdge], - dict, # 新增:返回去重详情 + dict ]: """ 执行两层实体去重与融合: diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 00d06f72..6e94a84f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -31,11 +31,10 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, StatementChunkEdge, StatementEntityEdge, - StatementNode, + StatementNode ) from app.core.memory.models.message_models import DialogData from app.core.memory.models.ontology_extraction_models import OntologyTypeList -from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, ) @@ -46,7 +45,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb embedding_generation, generate_entity_embeddings_from_triplets, ) - # 导入各个提取模块 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( StatementExtractor, @@ -90,16 +88,16 @@ class ExtractionOrchestrator: """ def __init__( - self, - llm_client: LLMClient, - embedder_client: OpenAIEmbedderClient, - connector: Neo4jConnector, - config: Optional[ExtractionPipelineConfig] = None, - progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, - embedding_id: Optional[str] = None, - ontology_types: Optional[OntologyTypeList] = None, - enable_general_types: bool = True, - language: str = "zh", + self, + llm_client: LLMClient, + embedder_client: OpenAIEmbedderClient, + connector: Neo4jConnector, + config: Optional[ExtractionPipelineConfig] = None, + progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, + embedding_id: Optional[str] = None, + ontology_types: Optional[OntologyTypeList] = None, + enable_general_types: bool = True, + language: str = "zh", ): """ 初始化流水线编排器 @@ -123,7 +121,7 @@ class ExtractionOrchestrator: self.progress_callback = progress_callback # 保存进度回调函数 self.embedding_id = embedding_id # 保存嵌入模型ID self.language = language # 保存语言配置 - + # 处理本体类型配置 # 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并 # 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合 @@ -146,7 +144,7 @@ class ExtractionOrchestrator: self.ontology_types = ontology_types if not enable_general_types and ontology_types: logger.info("enable_general_types=False,仅使用场景类型") - + # 保存去重消歧的详细记录(内存中的数据结构) self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录 self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录 @@ -157,19 +155,25 @@ class ExtractionOrchestrator: llm_client=llm_client, config=self.config.statement_extraction, ) - self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language) + self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types, + language=language) self.temporal_extractor = TemporalExtractor(llm_client=llm_client) logger.info("ExtractionOrchestrator 初始化完成") async def run( - self, - dialog_data_list: List[DialogData], - is_pilot_run: bool = False, - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialog_data_list: List[DialogData], + is_pilot_run: bool = False, + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + dict ]: """ 运行完整的知识提取流水线(优化版:并行执行) @@ -202,13 +206,12 @@ class ExtractionOrchestrator: # 步骤 1: 陈述句提取 logger.info("步骤 1/6: 陈述句提取(全局分块级并行)") dialog_data_list = await self._extract_statements(dialog_data_list) - + # 收集陈述句内容和统计数量 all_statements_list = [] for dialog in dialog_data_list: for chunk in dialog.chunks: all_statements_list.extend(chunk.statements) - len(all_statements_list) # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") @@ -220,7 +223,7 @@ class ExtractionOrchestrator: chunk_embedding_maps, dialog_embeddings, ) = await self._parallel_extract_and_embed(dialog_data_list) - + # 收集实体和三元组内容,并统计数量 all_entities_list = [] all_triplets_list = [] @@ -229,10 +232,6 @@ class ExtractionOrchestrator: if triplet_info: all_entities_list.extend(triplet_info.entities) all_triplets_list.extend(triplet_info.triplets) - - len(all_entities_list) - len(all_triplets_list) - sum(len(temporal_map) for temporal_map in temporal_maps) # 步骤 3: 生成实体嵌入(依赖三元组提取结果) logger.info("步骤 3/6: 生成实体嵌入") @@ -252,9 +251,9 @@ class ExtractionOrchestrator: # 步骤 5: 创建节点和边 logger.info("步骤 5/6: 创建节点和边") - + # 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送 - + ( dialogue_nodes, chunk_nodes, @@ -273,9 +272,9 @@ class ExtractionOrchestrator: logger.info("步骤 6/6: 去重和消歧(试运行模式:仅第一层去重)") else: logger.info("步骤 6/6: 两阶段去重和消歧") - + # 注意:deduplication 消息已在创建节点和边完成后立即发送 - + result = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -287,8 +286,6 @@ class ExtractionOrchestrator: dialog_data_list, ) - - logger.info(f"知识提取流水线运行完成({mode_str})") return result @@ -297,7 +294,7 @@ class ExtractionOrchestrator: raise async def _extract_statements( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[DialogData]: """ 从对话中提取陈述句(流式输出版本:边提取边发送进度) @@ -313,7 +310,7 @@ class ExtractionOrchestrator: # 收集所有分块及其元数据 all_chunks = [] chunk_metadata = [] # (dialog_idx, chunk_idx) - + for d_idx, dialog in enumerate(dialog_data_list): dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None for c_idx, chunk in enumerate(dialog.chunks): @@ -321,7 +318,7 @@ class ExtractionOrchestrator: chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") - + # 用于跟踪已完成的分块数量 completed_chunks = 0 total_chunks = len(all_chunks) @@ -332,7 +329,7 @@ class ExtractionOrchestrator: chunk, end_user_id, dialogue_content = chunk_data try: statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content) - + # 流式输出:每提取完一个分块的陈述句,立即发送进度 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 completed_chunks += 1 @@ -347,11 +344,11 @@ class ExtractionOrchestrator: "statement_index_in_chunk": idx + 1 } await self.progress_callback( - "knowledge_extraction_result", - f"陈述句提取中 ({completed_chunks}/{total_chunks})", + "knowledge_extraction_result", + f"陈述句提取中 ({completed_chunks}/{total_chunks})", stmt_result ) - + return statements except Exception as e: logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}") @@ -381,7 +378,7 @@ class ExtractionOrchestrator: # 保存陈述句到文件(试运行和正式模式都需要) self.statement_extractor.save_statements(all_statements) - + logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句") # 试运行模式下,所有分块提取完成后发送完成事件 @@ -395,7 +392,7 @@ class ExtractionOrchestrator: return dialog_data_list async def _extract_triplets( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取三元组(流式输出版本:边提取边发送进度) @@ -411,7 +408,7 @@ class ExtractionOrchestrator: # 收集所有陈述句及其元数据 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, chunk_content) - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -419,7 +416,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组") - + # 用于跟踪已完成的陈述句数量 completed_statements = 0 len(all_statements) @@ -430,11 +427,11 @@ class ExtractionOrchestrator: statement, chunk_content = stmt_data try: triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content) - + # 注意:不再发送三元组提取的流式输出 # 三元组提取在后台执行,但不向前端发送详细信息 completed_statements += 1 - + return triplet_info except Exception as e: logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}") @@ -450,7 +447,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 triplet_maps = [{} for _ in dialog_data_list] all_responses = [] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -478,7 +475,7 @@ class ExtractionOrchestrator: return triplet_maps async def _extract_temporal( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取时间信息(流式输出版本:边提取边发送进度) @@ -502,13 +499,13 @@ class ExtractionOrchestrator: temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None) temporal_maps.append(temporal_map) return temporal_maps - + logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)") # 收集所有需要提取时间的陈述句 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, ref_dates) - + for d_idx, dialog in enumerate(dialog_data_list): # 获取参考日期 ref_dates = {} @@ -517,11 +514,11 @@ class ExtractionOrchestrator: ref_dates['conversation_date'] = dialog.metadata['conversation_date'] if 'publication_date' in dialog.metadata: ref_dates['publication_date'] = dialog.metadata['publication_date'] - + if not ref_dates: from datetime import datetime ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")} - + for chunk in dialog.chunks: for statement in chunk.statements: # 跳过 ATEMPORAL 类型的陈述句 @@ -531,7 +528,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取") - + # 用于跟踪已完成的时间提取数量 completed_temporal = 0 len(all_statements) @@ -542,11 +539,11 @@ class ExtractionOrchestrator: statement, ref_dates = stmt_data try: temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) - + # 注意:不再发送时间提取的流式输出 # 时间提取在后台执行,但不向前端发送详细信息 completed_temporal += 1 - + return temporal_range except Exception as e: logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}") @@ -559,7 +556,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 temporal_maps = [{} for _ in dialog_data_list] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -585,7 +582,7 @@ class ExtractionOrchestrator: return temporal_maps async def _extract_emotions( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) @@ -601,36 +598,36 @@ class ExtractionOrchestrator: # 收集所有陈述句及其配置 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id) - + # 获取第一个对话的config_id来加载配置 config_id = None if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): config_id = dialog_data_list[0].config_id - + # 加载MemoryConfig memory_config = None if config_id: try: from app.db import SessionLocal from app.repositories.memory_config_repository import MemoryConfigRepository - + db = SessionLocal() try: memory_config = MemoryConfigRepository.get_by_id(db, config_id) finally: db.close() - + if memory_config and not memory_config.emotion_enabled: logger.info("情绪提取已在配置中禁用,跳过情绪提取") return [{} for _ in dialog_data_list] - + except Exception as e: logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取") return [{} for _ in dialog_data_list] else: logger.info("未找到config_id,跳过情绪提取") return [{} for _ in dialog_data_list] - + # 如果配置未启用情绪提取,直接返回空映射 if not memory_config or not memory_config.emotion_enabled: logger.info("情绪提取未启用,跳过") @@ -639,7 +636,7 @@ class ExtractionOrchestrator: # 收集所有陈述句(只收集 speaker 为 "user" 的) total_statements = 0 filtered_statements = 0 - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -655,12 +652,12 @@ class ExtractionOrchestrator: # 初始化情绪提取服务 # 如果 emotion_model_id 为空,回退到工作空间默认 LLM from app.services.emotion_extraction_service import EmotionExtractionService - + emotion_model_id = memory_config.emotion_model_id if not emotion_model_id and memory_config.workspace_id: from app.repositories.workspace_repository import get_workspace_models_configs from app.db import SessionLocal - + db = SessionLocal() try: workspace_models = get_workspace_models_configs(db, memory_config.workspace_id) @@ -669,7 +666,7 @@ class ExtractionOrchestrator: logger.info(f"emotion_model_id 为空,使用工作空间默认 LLM: {emotion_model_id}") finally: db.close() - + emotion_service = EmotionExtractionService( llm_id=emotion_model_id if emotion_model_id else None ) @@ -689,7 +686,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 emotion_maps = [{} for _ in dialog_data_list] successful_extractions = 0 - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -706,7 +703,7 @@ class ExtractionOrchestrator: return emotion_maps async def _parallel_extract_and_embed( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[Dict[str, Any]], List[Dict[str, Any]], @@ -757,7 +754,7 @@ class ExtractionOrchestrator: triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - + if isinstance(results[3], Exception): logger.error(f"基础嵌入生成失败: {results[3]}") statement_embedding_maps = [{} for _ in dialog_data_list] @@ -777,7 +774,7 @@ class ExtractionOrchestrator: ) async def _generate_basic_embeddings( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: """ 生成基础嵌入向量(陈述句、分块、对话) @@ -810,7 +807,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") raise ValueError("embedding_id is required but was not provided") - + # 只生成陈述句、分块和对话的嵌入(不包括实体) statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation( dialog_data_list, self.embedding_id @@ -836,7 +833,7 @@ class ExtractionOrchestrator: ) async def _generate_entity_embeddings( - self, triplet_maps: List[Dict[str, Any]] + self, triplet_maps: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """ 生成实体嵌入向量 @@ -861,7 +858,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") return triplet_maps - + # 生成实体嵌入 updated_triplet_maps = await generate_entity_embeddings_from_triplets( triplet_maps, self.embedding_id @@ -874,17 +871,15 @@ class ExtractionOrchestrator: logger.error(f"实体嵌入生成失败: {e}", exc_info=True) return triplet_maps - - async def _assign_extracted_data( - self, - dialog_data_list: List[DialogData], - temporal_maps: List[Dict[str, Any]], - triplet_maps: List[Dict[str, Any]], - emotion_maps: List[Dict[str, Any]], - statement_embedding_maps: List[Dict[str, List[float]]], - chunk_embedding_maps: List[Dict[str, List[float]]], - dialog_embeddings: List[List[float]], + self, + dialog_data_list: List[DialogData], + temporal_maps: List[Dict[str, Any]], + triplet_maps: List[Dict[str, Any]], + emotion_maps: List[Dict[str, Any]], + statement_embedding_maps: List[Dict[str, List[float]]], + chunk_embedding_maps: List[Dict[str, List[float]]], + dialog_embeddings: List[List[float]], ) -> List[DialogData]: """ 将提取的数据赋值到语句 @@ -906,12 +901,12 @@ class ExtractionOrchestrator: # 确保列表长度匹配 expected_length = len(dialog_data_list) if ( - len(temporal_maps) != expected_length - or len(triplet_maps) != expected_length - or len(emotion_maps) != expected_length - or len(statement_embedding_maps) != expected_length - or len(chunk_embedding_maps) != expected_length - or len(dialog_embeddings) != expected_length + len(temporal_maps) != expected_length + or len(triplet_maps) != expected_length + or len(emotion_maps) != expected_length + or len(statement_embedding_maps) != expected_length + or len(chunk_embedding_maps) != expected_length + or len(dialog_embeddings) != expected_length ): logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " @@ -999,7 +994,7 @@ class ExtractionOrchestrator: return dialog_data_list async def _create_nodes_and_edges( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -1007,7 +1002,7 @@ class ExtractionOrchestrator: List[ExtractedEntityNode], List[StatementChunkEdge], List[StatementEntityEdge], - List[EntityEntityEdge], + List[EntityEntityEdge] ]: """ 创建图数据库节点和边 @@ -1021,7 +1016,7 @@ class ExtractionOrchestrator: 包含所有节点和边的元组 """ logger.info("开始创建节点和边") - + # 注意:开始消息已在 run 方法中发送,这里不再重复发送 dialogue_nodes = [] @@ -1034,7 +1029,7 @@ class ExtractionOrchestrator: # 用于去重的集合 entity_id_set = set() - + # 用于跟踪进度 total_dialogs = len(dialog_data_list) processed_dialogs = 0 @@ -1083,15 +1078,19 @@ class ExtractionOrchestrator: name=f"Statement_{statement.id}", # 添加必需的 name 字段 chunk_id=chunk.id, stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 - temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 - connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), + # 添加必需的 temporal_info 字段 + connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 statement_embedding=statement.statement_embedding, - valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, - invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, + valid_at=statement.temporal_validity.valid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, + invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, @@ -1120,7 +1119,7 @@ class ExtractionOrchestrator: # 创建实体索引到ID的映射(支持多种索引方式) entity_idx_to_id = {} - + # 创建实体节点 for entity_idx, entity in enumerate(triplet_info.entities): # 映射实体索引到实体ID(使用多个键以提高容错性) @@ -1128,7 +1127,7 @@ class ExtractionOrchestrator: entity_idx_to_id[entity.entity_idx] = entity.id # 2. 使用枚举索引(从0开始) entity_idx_to_id[entity_idx] = entity.id - + if entity.id not in entity_id_set: entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') entity_node = ExtractedEntityNode( @@ -1141,7 +1140,8 @@ class ExtractionOrchestrator: example=getattr(entity, 'example', ''), # 新增:传递示例字段 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 - connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 @@ -1171,7 +1171,7 @@ class ExtractionOrchestrator: # 将三元组中的整数索引映射到实体ID subject_entity_id = entity_idx_to_id.get(triplet.subject_id) object_entity_id = entity_idx_to_id.get(triplet.object_id) - + # 只有当两个实体ID都存在时才创建边 if subject_entity_id and object_entity_id: entity_entity_edge = EntityEntityEdge( @@ -1186,7 +1186,7 @@ class ExtractionOrchestrator: expired_at=dialog_data.expired_at, ) entity_entity_edges.append(entity_entity_edge) - + # 流式输出:每创建一个关系边,立即发送进度(限制发送数量) if self.progress_callback and len(entity_entity_edges) <= 10: # 获取实体名称 @@ -1202,8 +1202,8 @@ class ExtractionOrchestrator: "dialog_progress": f"{processed_dialogs}/{total_dialogs}" } await self.progress_callback( - "creating_nodes_edges_result", - f"关系创建中 ({processed_dialogs}/{total_dialogs})", + "creating_nodes_edges_result", + f"关系创建中 ({processed_dialogs}/{total_dialogs})", relationship_result ) else: @@ -1211,7 +1211,7 @@ class ExtractionOrchestrator: missing_subject = "subject" if not subject_entity_id else "" missing_object = "object" if not object_entity_id else "" missing_both = " and " if (not subject_entity_id and not object_entity_id) else "" - + logger.debug( f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: " f"subject_id={triplet.subject_id} ({triplet.subject_name}), " @@ -1228,7 +1228,7 @@ class ExtractionOrchestrator: f"陈述句-实体边: {len(statement_entity_edges)}, " f"实体-实体边: {len(entity_entity_edges)}" ) - + # 进度回调:创建节点和边完成,传递结果统计 # 注意:具体的关系创建结果已经在创建过程中实时发送了 if self.progress_callback: @@ -1254,19 +1254,24 @@ class ExtractionOrchestrator: ) async def _run_dedup_and_write_summary( - self, - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + dict ]: """ 执行两阶段去重并写入汇总 @@ -1288,11 +1293,11 @@ class ExtractionOrchestrator: - 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) """ logger.info("开始两阶段实体去重和消歧") - + # 进度回调:发送去重消歧开始消息 if self.progress_callback: await self.progress_callback("deduplication", "正在去重消歧...") - + logger.info( f"去重前: {len(entity_nodes)} 个实体节点, " f"{len(statement_entity_edges)} 条陈述句-实体边, " @@ -1307,7 +1312,7 @@ class ExtractionOrchestrator: from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( deduplicate_entities_and_edges, ) - + dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges( entity_nodes, statement_entity_edges, @@ -1317,10 +1322,10 @@ class ExtractionOrchestrator: dedup_config=self.config.deduplication, llm_client=self.llm_client, ) - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes) - + result_tuple = ( dialogue_nodes, chunk_nodes, @@ -1330,7 +1335,7 @@ class ExtractionOrchestrator: dedup_statement_entity_edges, dedup_entity_entity_edges, ) - + final_entity_nodes = dedup_entity_nodes final_statement_entity_edges = dedup_statement_entity_edges final_entity_entity_edges = dedup_entity_entity_edges @@ -1361,7 +1366,7 @@ class ExtractionOrchestrator: final_entity_entity_edges, dedup_details, ) = result_tuple - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) @@ -1375,12 +1380,12 @@ class ExtractionOrchestrator: f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, " f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" ) - + # 流式输出:实时输出去重消歧的具体结果 if self.progress_callback: # 分析实体合并情况(使用内存中的记录) merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes) - + # 逐个输出去重合并的实体示例 for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果 dedup_result = { @@ -1391,10 +1396,10 @@ class ExtractionOrchestrator: "message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并" } await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result) - + # 分析实体消歧情况(使用内存中的记录) disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes) - + # 逐个输出实体消歧的结果 for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果 disamb_result = { @@ -1407,14 +1412,13 @@ class ExtractionOrchestrator: "message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}" } await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result) - + # 进度回调:去重消歧完成,传递去重和消歧的具体效果 await self._send_dedup_progress_callback( len(entity_nodes), len(final_entity_nodes), len(statement_entity_edges), len(final_statement_entity_edges), len(entity_entity_edges), len(final_entity_entity_edges) ) - # 写入提取结果汇总(试运行和正式模式都需要生成) try: @@ -1436,10 +1440,10 @@ class ExtractionOrchestrator: raise def _save_dedup_details( - self, - dedup_details: Dict[str, Any], - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + dedup_details: Dict[str, Any], + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ): """ 保存去重消歧的详细记录到实例变量(基于内存数据结构) @@ -1452,7 +1456,7 @@ class ExtractionOrchestrator: try: # 保存ID重定向映射 self.id_redirect_map = dedup_details.get("id_redirect", {}) - + # 处理精确匹配的合并记录 exact_merge_map = dedup_details.get("exact_merge_map", {}) for key, info in exact_merge_map.items(): @@ -1466,7 +1470,7 @@ class ExtractionOrchestrator: "merged_count": len(merged_ids), "merged_ids": list(merged_ids) }) - + # 处理模糊匹配的合并记录 fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", []) for record in fuzzy_merge_records: @@ -1486,7 +1490,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}") - + # 处理LLM去重的合并记录 llm_decision_records = dedup_details.get("llm_decision_records", []) for record in llm_decision_records: @@ -1505,7 +1509,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}") - + # 处理消歧记录 disamb_records = dedup_details.get("disamb_records", []) for record in disamb_records: @@ -1520,14 +1524,14 @@ class ExtractionOrchestrator: entity1_type = match.group(2) match.group(3).strip() entity2_type = match.group(4) - + # 提取置信度和原因 conf_match = re.search(r"conf=([0-9.]+)", str(record)) confidence = conf_match.group(1) if conf_match else "unknown" - + reason_match = re.search(r"reason=([^|]+)", str(record)) reason = reason_match.group(1).strip() if reason_match else "" - + self.dedup_disamb_records.append({ "entity_name": entity1_name, "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}", @@ -1536,16 +1540,17 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") - - logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") - + + logger.info( + f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") + except Exception as e: logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) async def _analyze_entity_merges( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) @@ -1566,28 +1571,28 @@ class ExtractionOrchestrator: key=lambda x: x.get("merged_count", 0), reverse=True ) - + merge_info = [] for record in sorted_records: merge_info.append({ "main_entity_name": record.get("entity_name", "未知实体"), "merged_count": record.get("merged_count", 1) }) - + return merge_info - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体合并记录") return [] - + except Exception as e: logger.error(f"分析实体合并情况失败: {e}", exc_info=True) return [] async def _analyze_entity_disambiguation( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) @@ -1603,11 +1608,11 @@ class ExtractionOrchestrator: # 直接使用保存的消歧记录 if self.dedup_disamb_records: return self.dedup_disamb_records - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体消歧记录") return [] - + except Exception as e: logger.error(f"分析实体消歧情况失败: {e}", exc_info=True) return [] @@ -1624,7 +1629,7 @@ class ExtractionOrchestrator: """ type_mapping = { "Person": "人物实体节点", - "Organization": "组织实体节点", + "Organization": "组织实体节点", "ORG": "组织实体节点", "Location": "地点实体节点", "LOC": "地点实体节点", @@ -1645,9 +1650,9 @@ class ExtractionOrchestrator: return type_mapping.get(entity_type, f"{entity_type}实体节点") async def _output_relationship_creation_results( - self, - entity_entity_edges: List[EntityEntityEdge], - entity_nodes: List[ExtractedEntityNode] + self, + entity_entity_edges: List[EntityEntityEdge], + entity_nodes: List[ExtractedEntityNode] ): """ 输出关系创建结果 @@ -1659,13 +1664,13 @@ class ExtractionOrchestrator: try: # 创建实体ID到名称的映射 entity_id_to_name = {node.id: node.name for node in entity_nodes} - + # 输出关系创建结果 for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系 source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}") target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}") relation_type = edge.relation_type - + relationship_result = { "result_type": "relationship_creation", "relationship_index": i + 1, @@ -1674,20 +1679,20 @@ class ExtractionOrchestrator: "target_entity": target_name, "relationship_text": f"{source_name} -[{relation_type}]-> {target_name}" } - + await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result) - + except Exception as e: logger.error(f"输出关系创建结果失败: {e}", exc_info=True) async def _send_dedup_progress_callback( - self, - original_entities: int, - final_entities: int, - original_stmt_edges: int, - final_stmt_edges: int, - original_ent_edges: int, - final_ent_edges: int, + self, + original_entities: int, + final_entities: int, + original_stmt_edges: int, + final_stmt_edges: int, + original_ent_edges: int, + final_ent_edges: int, ): """ 发送去重消歧完成的进度回调,传递具体的去重和消歧效果 @@ -1703,19 +1708,20 @@ class ExtractionOrchestrator: try: # 解析去重消歧报告文件,获取具体的去重和消歧效果 dedup_details = await self._parse_dedup_report() - + # 计算去重效果统计 entities_reduced = original_entities - final_entities stmt_edges_reduced = original_stmt_edges - final_stmt_edges ent_edges_reduced = original_ent_edges - final_ent_edges - + # 构建进度回调数据 dedup_stats = { "entities": { "original_count": original_entities, "final_count": final_entities, "reduced_count": entities_reduced, - "reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, + "reduction_rate": round(entities_reduced / original_entities * 100, + 1) if original_entities > 0 else 0, }, "statement_entity_edges": { "original_count": original_stmt_edges, @@ -1734,9 +1740,9 @@ class ExtractionOrchestrator: "total_disambiguations": dedup_details.get("total_disambiguations", 0), } } - + await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats) - + except Exception as e: logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True) # 即使解析失败,也发送基本的统计信息 @@ -1766,12 +1772,12 @@ class ExtractionOrchestrator: disamb_examples = [] total_merges = 0 total_disambiguations = 0 - + # 处理合并记录 for record in self.dedup_merge_records: merge_count = record.get("merged_count", 0) total_merges += merge_count - + dedup_examples.append({ "type": record.get("type", "未知"), "entity_name": record.get("entity_name", "未知实体"), @@ -1779,30 +1785,31 @@ class ExtractionOrchestrator: "merge_count": merge_count, "description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个" }) - + # 处理消歧记录 for record in self.dedup_disamb_records: total_disambiguations += 1 - + # 从消歧类型中提取实体类型信息 disamb_type = record.get("disamb_type", "") entity_name = record.get("entity_name", "未知实体") - + disamb_examples.append({ "entity1_name": entity_name, - "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", + "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", + "").strip() if "vs" in disamb_type else "未知", "entity2_name": entity_name, "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", "description": f"{entity_name},消歧区分成功" }) - + return { "dedup_examples": dedup_examples[:5], # 只返回前5个示例 "disamb_examples": disamb_examples[:5], # 只返回前5个示例 "total_merges": total_merges, "total_disambiguations": total_disambiguations, } - + except Exception as e: logger.error(f"获取去重报告失败: {e}", exc_info=True) return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0} @@ -1815,9 +1822,9 @@ class ExtractionOrchestrator: async def get_chunked_dialogs( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "group_1", - indices: Optional[List[int]] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "group_1", + indices: Optional[List[int]] = None, ) -> List[DialogData]: """从测试数据生成分块对话 @@ -1831,7 +1838,7 @@ async def get_chunked_dialogs( """ import json import re - + # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") with open(testdata_path, "r", encoding="utf-8") as f: @@ -1845,7 +1852,7 @@ async def get_chunked_dialogs( else: # 默认使用所有数据 selected_data = test_data - + for data in selected_data: # 解析对话上下文 context_text = data["context"] @@ -1861,7 +1868,7 @@ async def get_chunked_dialogs( if m: y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - + dialog_metadata: Dict[str, Any] = {} if conv_date: dialog_metadata["conversation_date"] = conv_date @@ -1890,7 +1897,7 @@ async def get_chunked_dialogs( end_user_id=end_user_id, metadata=dialog_metadata, ) - + # 创建分块器并处理对话 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, @@ -1913,7 +1920,7 @@ async def get_chunked_dialogs( from app.core.config import settings settings.ensure_memory_output_dir() output_path = settings.get_memory_output_path("chunker_test_output.txt") - + import json with open(output_path, "w", encoding="utf-8") as f: json.dump( @@ -1924,10 +1931,10 @@ async def get_chunked_dialogs( def preprocess_data( - input_path: Optional[str] = None, - output_path: Optional[str] = None, - skip_cleaning: bool = True, - indices: Optional[List[int]] = None + input_path: Optional[str] = None, + output_path: Optional[str] = None, + skip_cleaning: bool = True, + indices: Optional[List[int]] = None ) -> List[DialogData]: """数据预处理 @@ -1946,7 +1953,8 @@ def preprocess_data( ) preprocessor = DataPreprocessor() try: - cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) + cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, + skip_cleaning=skip_cleaning, indices=indices) logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") return cleaned_data except Exception as e: @@ -1955,9 +1963,9 @@ def preprocess_data( async def get_chunked_dialogs_from_preprocessed( - data: List[DialogData], - chunker_strategy: str = "RecursiveChunker", - llm_client: Optional[Any] = None, + data: List[DialogData], + chunker_strategy: str = "RecursiveChunker", + llm_client: Optional[Any] = None, ) -> List[DialogData]: """从预处理后的数据中生成分块 @@ -1972,31 +1980,31 @@ async def get_chunked_dialogs_from_preprocessed( logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===") if not data: raise ValueError("预处理数据为空,无法进行分块") - + all_chunked_dialogs: List[DialogData] = [] from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, ) - + for dialog_data in data: chunker = DialogueChunker(chunker_strategy, llm_client=llm_client) chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = chunks all_chunked_dialogs.append(dialog_data) - + return all_chunked_dialogs async def get_chunked_dialogs_with_preprocessing( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "default", - user_id: str = "default", - apply_id: str = "default", - indices: Optional[List[int]] = None, - input_data_path: Optional[str] = None, - llm_client: Optional[Any] = None, - skip_cleaning: bool = True, - pruning_config: Optional[Dict] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "default", + user_id: str = "default", + apply_id: str = "default", + indices: Optional[List[int]] = None, + input_data_path: Optional[str] = None, + llm_client: Optional[Any] = None, + skip_cleaning: bool = True, + pruning_config: Optional[Dict] = None, ) -> List[DialogData]: """包含数据预处理步骤的完整分块流程 @@ -2020,7 +2028,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path = os.path.join( os.path.dirname(__file__), "../../data", "testdata.json" ) - + # 步骤1: 数据预处理(包含索引筛选) from app.core.config import settings settings.ensure_memory_output_dir() @@ -2030,37 +2038,38 @@ async def get_chunked_dialogs_with_preprocessing( skip_cleaning=skip_cleaning, indices=indices, ) - + # 设置 end_user_id for dd in preprocessed_data: dd.end_user_id = end_user_id - + # 步骤2: 语义剪枝 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) from app.core.memory.models.config_models import PruningConfig - + # 构建剪枝配置 if pruning_config: # 使用传入的配置 config = PruningConfig(**pruning_config) - logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + logger.debug( + f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") else: # 使用默认配置(关闭剪枝) config = None logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") - + pruner = SemanticPruner(config=config, llm_client=llm_client) - + # 记录单对话场景下剪枝前的消息数量 single_dialog_original_msgs = None if len(preprocessed_data) == 1 and preprocessed_data[0].context: single_dialog_original_msgs = len(preprocessed_data[0].context.msgs) preprocessed_data = await pruner.prune_dataset(preprocessed_data) - + # 单对话:打印清洗与剪枝信息 if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None: remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0 @@ -2071,7 +2080,7 @@ async def get_chunked_dialogs_with_preprocessing( ) else: logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") - + # 保存剪枝后的数据 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import ( @@ -2084,7 +2093,7 @@ async def get_chunked_dialogs_with_preprocessing( logger.error(f"保存剪枝结果失败:{se}") except Exception as e: logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}") - + # 步骤3: 对话分块 return await get_chunked_dialogs_from_preprocessed( preprocessed_data, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index 443ee36a..5e39ba36 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -188,7 +188,6 @@ async def _process_chunk_summary( response_model=MemorySummaryResponse, ) summary_text = structured.summary.strip() - # Generate title and type for the summary title = None episodic_type = None diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index d4e1b488..60f1257e 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -374,7 +374,9 @@ class VariablePool: self.variables = deepcopy(pool.variables) def is_file_variable(self, selector): - variable_struct = self._get_variable_struct(selector) + variable_struct = self.get_instance(selector, default=None, strict=False) + if variable_struct is None: + return False if isinstance(variable_struct, FileVariable): return True elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 0e3fecee..7f2b8aa6 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -623,7 +623,6 @@ class BaseNode(ABC): async def process_message( api_config: ModelInfo, content: str | dict | FileObject, - end_user_id: str, enable_file=False ) -> list | str | None: provider = api_config.provider @@ -642,8 +641,8 @@ class BaseNode(ABC): return content elif isinstance(content, FileObject): - if content.content_cache.get(provider): - return content.content_cache[provider] + if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): + return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] with get_db_read() as db: multimodel_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( @@ -655,12 +654,11 @@ class BaseNode(ABC): ) file_obj.set_content(content.get_content()) message = await multimodel_service.process_files( - end_user_id, [file_obj], ) content.set_content(file_obj.get_content()) if message: - content.content_cache[provider] = message + content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message return message return None raise TypeError(f'Unexpect input value type - {type(content)}') diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index b293d1f4..66a0f1ac 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -144,7 +144,6 @@ class LLMNode(BaseNode): f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") messages_config = self.typed_config.messages - if messages_config: # 使用 LangChain 消息格式 messages = [] @@ -153,7 +152,6 @@ class LLMNode(BaseNode): content_template = msg_config.content content_template = self._render_context(content_template, variable_pool) content = self._render_template(content_template, variable_pool) - user_id = self.get_variable("sys.user_id", variable_pool) # 根据角色创建对应的消息对象 if role == "system": messages.append({ @@ -161,32 +159,31 @@ class LLMNode(BaseNode): "content": await self.process_message( model_info, content, - user_id, self.typed_config.vision, ) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file.value, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -200,7 +197,7 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(model_info, file, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file, self.typed_config.vision) if content: file_content.extend(content) history_message.append( @@ -210,7 +207,6 @@ class LLMNode(BaseNode): message["content"] = await self.process_message( model_info, message["content"], - user_id, self.typed_config.vision ) history_message.append(message) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 82363056..cbdad0fa 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode): write_message_task.delay( end_user_id=end_user_id, message=messages, + file_messages=multimodal_memories, config_id=str(self.typed_config.config_id), storage_type=state["memory_storage_type"], user_rag_memory_id=state["user_rag_memory_id"] diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py index 1095a386..616f7f3a 100644 --- a/api/app/models/memory_config_model.py +++ b/api/app/models/memory_config_model.py @@ -30,6 +30,9 @@ class MemoryConfig(Base): llm_id = Column(String, nullable=True, comment="LLM模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") + vision_id = Column(String, nullable=True, comment="视觉模型配置ID") + audio_id = Column(String, nullable=True, comment="语音模型配置ID") + video_id = Column(String, nullable=True, comment="视频模型配置ID") # 记忆萃取引擎配置 enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 5c2f81a7..6fb41914 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -9,21 +9,22 @@ Classes: """ import uuid -from uuid import UUID from typing import Dict, List, Optional, Tuple +from uuid import UUID + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session + from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger from app.models.memory_config_model import MemoryConfig +from app.models.workspace_model import Workspace from app.schemas.memory_storage_schema import ( - ConfigKey, ConfigParamsCreate, ConfigUpdate, ConfigUpdateExtracted, ConfigUpdateForget, ) -from sqlalchemy import desc, select -from sqlalchemy.orm import Session - from app.utils.config_utils import resolve_config_id # 获取数据库专用日志器 @@ -157,7 +158,7 @@ class MemoryConfigRepository: return memory_config_obj @staticmethod - def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: + def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: @@ -491,7 +492,10 @@ class MemoryConfigRepository: raise @staticmethod - def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: + def get_config_with_workspace( + db: Session, + config_id: uuid.UUID | int | str + ) -> Optional[tuple[MemoryConfig, Workspace]]: """Get memory config and its associated workspace information Args: @@ -506,8 +510,6 @@ class MemoryConfigRepository: """ import time - from app.models.workspace_model import Workspace - start_time = time.time() config_id = resolve_config_id(config_id, db) @@ -594,7 +596,7 @@ class MemoryConfigRepository: db_logger.debug( f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") - return (config, workspace) + return config, workspace except ValueError: # Re-raise known business exceptions @@ -630,7 +632,7 @@ class MemoryConfigRepository: List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ from app.models.ontology_scene import OntologyScene - + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: @@ -694,7 +696,7 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 默认配置对象,不存在则返回None """ db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}") - + try: # 优先查找显式标记为默认的配置 stmt = ( @@ -706,13 +708,13 @@ class MemoryConfigRepository: ) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"找到默认配置: config_id={config.config_id}") return config - + # 回退:获取最早创建的活跃配置 stmt = ( select(MemoryConfig) @@ -723,25 +725,25 @@ class MemoryConfigRepository: .order_by(MemoryConfig.created_at.asc()) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}") else: db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}") - + return config - + except Exception as e: db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}") raise @staticmethod def get_with_fallback( - db: Session, - config_id: Optional[uuid.UUID], - workspace_id: uuid.UUID + db: Session, + config_id: Optional[uuid.UUID], + workspace_id: uuid.UUID ) -> Optional[MemoryConfig]: """获取记忆配置,支持回退到工作空间默认配置 @@ -756,19 +758,18 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 配置对象,如果都不存在则返回None """ db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}") - + if not config_id: db_logger.debug("config_id 为空,使用工作空间默认配置") return MemoryConfigRepository.get_workspace_default(db, workspace_id) - + config = db.get(MemoryConfig, config_id) - + if config: return config - + db_logger.warning( f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}" ) - - return MemoryConfigRepository.get_workspace_default(db, workspace_id) + return MemoryConfigRepository.get_workspace_default(db, workspace_id) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 42c178b3..3a017089 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,7 +1,8 @@ from typing import List, Optional -from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode +from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ + MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): print(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result + async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add dialogue nodes to Neo4j database. @@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC print(f"Error creating statement nodes: {e}") return None + async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add chunk nodes to Neo4j in batch. @@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> return None - -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: +async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[ + List[str]]: """Add memory summary nodes to Neo4j in batch. Args: @@ -211,7 +214,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector "summary_embedding": s.summary_embedding if s.summary_embedding else None, "config_id": s.config_id, # 添加 config_id }) - + result = await connector.execute_query( MEMORY_SUMMARY_NODE_SAVE, summaries=flattened @@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector return None +async def add_perceptual_nodes( + perceptuals: list, + connector: Neo4jConnector, + embedder_client=None, +) -> Optional[List[str]]: + """Add perceptual memory nodes to Neo4j in batch. + + Args: + perceptuals: List of MemoryPerceptualModel objects from PostgreSQL + connector: Neo4j connector instance + embedder_client: Optional embedder client for generating summary embeddings + + Returns: + List of created node UUIDs or None if failed + """ + if not perceptuals: + print("No perceptual nodes to add") + return [] + + try: + flattened = [] + for p in perceptuals: + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + # 生成 summary embedding(如果有 embedder_client) + summary_embedding = None + if embedder_client and p.summary: + try: + summary_embedding = (await embedder_client.response([p.summary]))[0] + except Exception as emb_err: + print(f"Failed to embed perceptual summary: {emb_err}") + + flattened.append({ + "id": str(p.id), + "end_user_id": str(p.end_user_id), + "perceptual_type": p.perceptual_type, + "file_path": p.file_path or "", + "file_name": p.file_name or "", + "file_ext": p.file_ext or "", + "summary": p.summary or "", + "keywords": content_meta.get("keywords", []), + "topic": content_meta.get("topic", ""), + "domain": content_meta.get("domain", ""), + "created_at": p.created_time.isoformat() if p.created_time else None, + "summary_embedding": summary_embedding, + }) + + result = await connector.execute_query( + PERCEPTUAL_NODE_SAVE, + perceptuals=flattened, + ) + created_uuids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j") + return created_uuids + + except Exception as e: + print(f"Failed to save Perceptual nodes to Neo4j: {e}") + return None + + +async def add_perceptual_dialogue_edges( + perceptuals: list, + dialog_id: str, + connector: Neo4jConnector, +) -> Optional[List[str]]: + """Add edges between Perceptual nodes and Dialogue nodes. + + Args: + perceptuals: List of MemoryPerceptualModel objects + dialog_id: The dialogue ID (or ref_id) to link to + connector: Neo4j connector instance + + Returns: + List of created edge element IDs or None if failed + """ + if not perceptuals or not dialog_id: + return [] + + try: + edges = [] + for p in perceptuals: + edges.append({ + "perceptual_id": str(p.id), + "dialog_id": dialog_id, + "end_user_id": str(p.end_user_id), + "created_at": p.created_time.isoformat() if p.created_time else None, + }) + + result = await connector.execute_query( + PERCEPTUAL_DIALOGUE_EDGE_SAVE, + edges=edges, + ) + created_ids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j") + return created_ids + + except Exception as e: + print(f"Failed to save Perceptual-Dialogue edges: {e}") + return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 0ac7dcb1..49dbe2a5 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1323,3 +1323,36 @@ RETURN s.statement AS statement, ORDER BY COALESCE(s.activation_value, 0) DESC LIMIT $limit """ + +# 感知记忆节点保存 +PERCEPTUAL_NODE_SAVE = """ +UNWIND $perceptuals AS p +MERGE (n:Perceptual {id: p.id}) +SET n += { + id: p.id, + end_user_id: p.end_user_id, + perceptual_type: p.perceptual_type, + file_path: p.file_path, + file_name: p.file_name, + file_ext: p.file_ext, + summary: p.summary, + keywords: p.keywords, + topic: p.topic, + domain: p.domain, + created_at: p.created_at, + summary_embedding: p.summary_embedding +} +RETURN n.id AS uuid +""" + +# 感知记忆与对话的关联边 +PERCEPTUAL_DIALOGUE_EDGE_SAVE = """ +UNWIND $edges AS edge +MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id}) +MATCH (d:Dialogue {end_user_id: edge.end_user_id}) +WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id +MERGE (d)-[r:HAS_PERCEPTUAL]->(p) +SET r.end_user_id = edge.end_user_id, + r.created_at = edge.created_at +RETURN elementId(r) AS uuid +""" diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 8d7490fe..e186e54b 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -387,6 +387,12 @@ class MemoryConfig: rerank_model_id: Optional[UUID] = None rerank_model_name: Optional[str] = None + video_model_id: Optional[UUID] = None + video_model_name: Optional[str] = None + vision_model_id: Optional[UUID] = None + vision_model_name: Optional[str] = None + audio_model_id: Optional[UUID] = None + audio_model_name: Optional[str] = None llm_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 604514b4..98f93408 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -141,7 +141,7 @@ class AppChatService: model_type=ModelType.LLM ) multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 调用 Agent(支持多模态) @@ -339,7 +339,7 @@ class AppChatService: model_type=ModelType.LLM ) multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 流式调用 Agent(支持多模态),同时并行启动 TTS diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ba41d323..f7331851 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -600,7 +600,7 @@ class AgentRunService: ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -836,7 +836,7 @@ class AgentRunService: ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 514cb12f..875f02bb 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional from uuid import UUID import redis -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session +from app.cache import InterestMemoryCache from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.messages_tools import ( merge_multiple_search_results, reorder_output_results, ) from app.core.memory.agent.utils.type_classifier import status_typle +from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas import FileInput from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from app.services.memory_perceptual_service import MemoryPerceptualService try: from app.core.memory.utils.log.audit_logger import audit_logger @@ -271,6 +274,7 @@ class MemoryAgentService: self, end_user_id: str, messages: list[dict], + file_messages: list[dict], config_id: Optional[uuid.UUID] | int, db: Session, storage_type: str, @@ -283,6 +287,7 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) messages: Message to write + files: Files to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -342,48 +347,52 @@ class MemoryAgentService: raise ValueError(error_msg) + perceptual_serivce = MemoryPerceptualService(db) + file_content = [] + for message in file_messages: + for file in message["files"]: + file_object = await perceptual_serivce.generate_perceptual_memory( + end_user_id=end_user_id, + memory_config=memory_config, + file=FileInput(**file) + ) + file_content.append(file_object) + + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: if storage_type == "rag": # For RAG storage, convert messages to single string - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) await write_rag(end_user_id, message_text, user_rag_memory_id) return "success" else: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # Convert structured messages to LangChain messages - langchain_messages = [] - for msg in messages: - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - print(100 * '-') - print(langchain_messages) - print(100 * '-') - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config, - "language": language - } - - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - # Convert messages back to string for logging - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, - contents) + await write_neo4j( + end_user_id=end_user_id, + messages=messages, + file_content=file_content, + memory_config=memory_config, + ref_id='', + language=language + ) + for lang in ["zh", "en"]: + deleted = await InterestMemoryCache.delete_interest_distribution( + end_user_id, lang + ) + if deleted: + logger.info( + f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") + return self.writer_messages_deal( + "success", + start_time, + end_user_id, + config_id, + message_text, + { + "status": "success", + "data": messages, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name + } + ) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 01bc6267..9a0fb8ed 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -28,7 +28,7 @@ class MemoryAPIService: 2. Maps end_user_id to end_user_id for memory operations 3. Delegates to MemoryAgentService for actual memory read/write operations """ - + def __init__(self, db: Session): """Initialize MemoryAPIService. @@ -36,11 +36,11 @@ class MemoryAPIService: db: SQLAlchemy database session """ self.db = db - + def validate_end_user( - self, - end_user_id: str, - workspace_id: uuid.UUID + self, + end_user_id: str, + workspace_id: uuid.UUID ) -> EndUser: """Validate that end_user exists and belongs to the workspace. @@ -56,7 +56,7 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace """ logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}") - + # Query end_user by ID try: end_user_uuid = uuid.UUID(end_user_id) @@ -66,7 +66,7 @@ class MemoryAPIService: message=f"Invalid end_user_id format: {end_user_id}", code=BizCode.INVALID_PARAMETER ) - + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first() if not end_user: @@ -75,13 +75,13 @@ class MemoryAPIService: resource_type="EndUser", resource_id=end_user_id ) - + # Verify end_user belongs to the workspace via App relationship app = self.db.query(App).filter( App.id == end_user.app_id, App.is_active.is_(True) ).first() - + if not app: logger.warning(f"App not found for end_user: {end_user_id}") # raise ResourceNotFoundException( @@ -99,7 +99,7 @@ class MemoryAPIService: # message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}", # code=BizCode.FORBIDDEN # ) - + logger.info(f"End user {end_user_id} validated successfully") return end_user @@ -125,13 +125,14 @@ class MemoryAPIService: logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") async def write_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - config_id: str, - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + config_id: str, + storage_type: str = "neo4j", + files: Optional[list]=None, + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Write memory with validation. @@ -153,14 +154,16 @@ class MemoryAPIService: ResourceNotFoundException: If end_user not found BusinessException: If end_user not in authorized workspace or write fails """ + if files is None: + files = list() logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - + try: # Delegate to MemoryAgentService # Convert string message to list[dict] format expected by MemoryAgentService @@ -171,11 +174,12 @@ class MemoryAPIService: config_id=config_id, db=self.db, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id or "" + user_rag_memory_id=user_rag_memory_id or "", + files=files ) - + logger.info(f"Memory write successful for end_user: {end_user_id}") - + # result may be a string "success" or a dict with a "status" key # Preserve the full dict so callers don't silently lose extra fields # (e.g. error codes, metadata) returned by MemoryAgentService. @@ -189,7 +193,7 @@ class MemoryAPIService: "status": result if isinstance(result, str) else "success", "end_user_id": end_user_id, } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -204,16 +208,16 @@ class MemoryAPIService: message=f"Memory write failed: {str(e)}", code=BizCode.MEMORY_WRITE_FAILED ) - + async def read_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - search_switch: str = "0", - config_id: str = "", - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + search_switch: str = "0", + config_id: str = "", + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Read memory with validation. @@ -237,14 +241,13 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace or read fails """ logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - try: # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( @@ -257,15 +260,15 @@ class MemoryAPIService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id or "" ) - + logger.info(f"Memory read successful for end_user: {end_user_id}") - + return { "answer": result.get("answer", ""), "intermediate_outputs": result.get("intermediate_outputs", []), "end_user_id": end_user_id } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -282,8 +285,8 @@ class MemoryAPIService: ) def list_memory_configs( - self, - workspace_id: uuid.UUID, + self, + workspace_id: uuid.UUID, ) -> Dict[str, Any]: """List all memory configs for a workspace. diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index a3751c07..1a4af531 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None): """Validate configuration ID format (supports both UUID and integer).""" if isinstance(config_id, uuid.UUID): return config_id - + if config_id is None: raise InvalidConfigError( "Configuration ID cannot be None", @@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None): if result: logger.info(f"Found config_id {result.config_id} for user_id {config_id}") return result.config_id - + return config_id if isinstance(config_id, str): config_id_stripped = config_id.strip() - + # Try parsing as UUID first try: return uuid.UUID(config_id_stripped) except ValueError: pass - + # Fall back to integer parsing try: parsed_id = int(config_id_stripped) @@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None): field_name="config_id", invalid_value=config_id, ) - + # 如果提供了数据库会话,尝试通过 user_id 查询 config_id if db is not None: # 查询 user_id 匹配的记录 stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id)) result = db.execute(stmt).scalars().first() - + if result: logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}") return result.config_id - + return parsed_id except ValueError: raise InvalidConfigError( @@ -154,10 +154,10 @@ class MemoryConfigService: self.db = db def load_memory_config( - self, - config_id: Optional[UUID] = None, - workspace_id: Optional[UUID] = None, - service_name: str = "MemoryConfigService", + self, + config_id: Optional[UUID] = None, + workspace_id: Optional[UUID] = None, + service_name: str = "MemoryConfigService", ) -> MemoryConfig: """ Load memory configuration from database with optional fallback. @@ -194,14 +194,14 @@ class MemoryConfigService: try: # Use get_config_with_fallback if workspace_id is provided memory_config = None + validated_config_id = None if workspace_id: - validated_config_id = None if config_id: try: validated_config_id = _validate_config_id(config_id, self.db) except Exception: validated_config_id = None - + memory_config = self.get_config_with_fallback( memory_config_id=validated_config_id, workspace_id=workspace_id @@ -210,7 +210,7 @@ class MemoryConfigService: validated_config_id = _validate_config_id(config_id, self.db) from app.models.memory_config_model import MemoryConfig as MemoryConfigModel memory_config = self.db.get(MemoryConfigModel, validated_config_id) - + if not memory_config: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -233,7 +233,7 @@ class MemoryConfigService: result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id) db_query_time = time.time() - db_query_start logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") - + if not result: raise ConfigurationError( f"Workspace not found for config {memory_config.config_id}" @@ -243,10 +243,10 @@ class MemoryConfigService: # Helper function to validate model with workspace fallback def _validate_model_with_fallback( - model_id: str, - model_type: str, - workspace_default: str, - required: bool = False + model_id: str, + model_type: str, + workspace_default: str, + required: bool = False ) -> tuple: """Validate model ID, falling back to workspace default if invalid. @@ -275,7 +275,7 @@ class MemoryConfigService: logger.warning( f"{model_type} model validation failed, trying workspace default: {e}" ) - + # Fallback to workspace default if workspace_default: try: @@ -297,7 +297,7 @@ class MemoryConfigService: logger.error(f"Workspace default {model_type} model also invalid: {e}") if required: raise - + if required: raise InvalidConfigError( f"{model_type.title()} model is required but not configured", @@ -306,7 +306,7 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id ) - + return None, None # Step 2: Validate embedding model with workspace fallback @@ -343,6 +343,35 @@ class MemoryConfigService: if memory_config.rerank_id or workspace.rerank: logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") + vision_uuid, vision_name = validate_and_resolve_model_id( + memory_config.vision_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + audio_uuid, audio_name = validate_and_resolve_model_id( + memory_config.audio_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + video_uuid, video_name = validate_and_resolve_model_id( + memory_config.video_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) # Create immutable MemoryConfig object config = MemoryConfig( config_id=memory_config.config_id, @@ -356,6 +385,12 @@ class MemoryConfigService: embedding_model_name=embedding_name, rerank_model_id=rerank_uuid, rerank_model_name=rerank_name, + video_model_id=video_uuid, + video_model_name=video_name, + vision_model_id=vision_uuid, + vision_model_name=vision_name, + audio_model_id=audio_uuid, + audio_model_name=audio_name, storage_type=workspace.storage_type or "neo4j", chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker", reflexion_enabled=memory_config.enable_self_reflexion or False, @@ -364,24 +399,31 @@ class MemoryConfigService: reflexion_baseline=memory_config.baseline or "Time", loaded_at=datetime.now(), # Pipeline config: Deduplication - enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, - enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, + enable_llm_dedup_blockwise=bool( + memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, + enable_llm_disambiguation=bool( + memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True, t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8, t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8, t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8, # Pipeline config: Statement extraction - statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, - include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, - max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000, + statement_granularity=int( + memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, + include_dialogue_context=bool( + memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, + max_dialogue_context_chars=int( + memory_config.max_context) if memory_config.max_context is not None else 1000, # Pipeline config: Forgetting engine lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5, lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5, offset=float(memory_config.offset) if memory_config.offset is not None else 0.0, # Pipeline config: Pruning - pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, + pruning_enabled=bool( + memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, pruning_scene=memory_config.pruning_scene or "education", - pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, + pruning_threshold=float( + memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, # Ontology scene association scene_id=memory_config.scene_id, ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id), @@ -448,9 +490,9 @@ class MemoryConfigService: if not config: logger.warning(f"Model ID {model_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -481,9 +523,9 @@ class MemoryConfigService: if not config: logger.warning(f"Embedding model ID {embedding_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -571,25 +613,25 @@ class MemoryConfigService: """ from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.repositories.ontology_class_repository import OntologyClassRepository - + if not memory_config.scene_id: logger.debug("No scene_id configured, skipping ontology type fetch") return None - + try: ontology_repo = OntologyClassRepository(self.db) ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id) - + if not ontology_classes: logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}") return None - + ontology_types = OntologyTypeList.from_db_models(ontology_classes) logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" ) return ontology_types - + except Exception as e: logger.warning( f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}", @@ -598,8 +640,8 @@ class MemoryConfigService: return None def get_workspace_default_config( - self, - workspace_id: UUID + self, + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get workspace default memory config. @@ -613,19 +655,19 @@ class MemoryConfigService: Optional[MemoryConfigModel]: Default config or None if no configs exist """ config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id) - + if not config: logger.warning( "No active memory config found for workspace fallback", extra={"workspace_id": str(workspace_id)} ) - + return config def get_config_with_fallback( - self, - memory_config_id: Optional[UUID], - workspace_id: UUID + self, + memory_config_id: Optional[UUID], + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get memory config with fallback to workspace default. @@ -644,13 +686,13 @@ class MemoryConfigService: "No memory config ID provided, using workspace default", extra={"workspace_id": str(workspace_id)} ) - + config = MemoryConfigRepository.get_with_fallback( self.db, memory_config_id, workspace_id ) - + if not config and memory_config_id: logger.warning( "Memory config not found, falling back to workspace default", @@ -659,13 +701,13 @@ class MemoryConfigService: "workspace_id": str(workspace_id) } ) - + return config def delete_config( - self, - config_id: UUID | int, - force: bool = False + self, + config_id: UUID | int, + force: bool = False ) -> dict: """Delete memory config with protection against in-use configs. @@ -687,7 +729,7 @@ class MemoryConfigService: from app.core.exceptions import ResourceNotFoundException from app.models.memory_config_model import MemoryConfig as MemoryConfigModel from app.repositories.end_user_repository import EndUserRepository - + # 处理旧格式 int 类型的 config_id if isinstance(config_id, int): logger.warning( @@ -699,11 +741,11 @@ class MemoryConfigService: "message": "旧格式配置ID不支持删除操作,请使用新版配置", "legacy_int_id": config_id } - + config = self.db.get(MemoryConfigModel, config_id) if not config: raise ResourceNotFoundException("MemoryConfig", str(config_id)) - + # Check if this is the default config - default configs cannot be deleted if config.is_default: logger.warning( @@ -715,11 +757,11 @@ class MemoryConfigService: "message": "默认配置不允许删除", "is_default": True } - + # Use repository to count connected end users end_user_repo = EndUserRepository(self.db) connected_count = end_user_repo.count_by_memory_config_id(config_id) - + if connected_count > 0 and not force: logger.warning( "Attempted to delete memory config with connected end users", @@ -728,18 +770,18 @@ class MemoryConfigService: "connected_count": connected_count } ) - + return { "status": "warning", "message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置", "connected_count": connected_count, "force_required": True } - + # Force delete: use repository to clear end user references first if connected_count > 0 and force: cleared_count = end_user_repo.clear_memory_config_id(config_id) - + logger.warning( "Force deleting memory config, clearing end user references", extra={ @@ -747,11 +789,11 @@ class MemoryConfigService: "cleared_end_users": cleared_count } ) - + try: self.db.delete(config) self.db.commit() - + logger.info( "Memory config deleted", extra={ @@ -760,16 +802,16 @@ class MemoryConfigService: "affected_users": connected_count } ) - + return { "status": "success", "message": "记忆配置删除成功", "affected_users": connected_count } - + except IntegrityError as e: self.db.rollback() - + # Handle foreign key violation gracefully error_str = str(e.orig) if e.orig else str(e) if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower(): @@ -785,7 +827,7 @@ class MemoryConfigService: "message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除", "force_required": True } - + # Re-raise other integrity errors logger.error( "Delete failed due to integrity error", @@ -800,9 +842,9 @@ class MemoryConfigService: # ==================== 记忆配置提取方法 ==================== def extract_memory_config_id( - self, - app_type: str, - config: dict + self, + app_type: str, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从发布配置中提取 memory_config_id(根据应用类型分发) @@ -828,8 +870,8 @@ class MemoryConfigService: return None, False def _extract_memory_config_id_from_agent( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Agent 应用配置中提取 memory_config_id @@ -888,8 +930,8 @@ class MemoryConfigService: return None, False def _extract_memory_config_id_from_workflow( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Workflow 应用配置中提取 memory_config_id @@ -905,14 +947,14 @@ class MemoryConfigService: - is_legacy_int: 是否检测到旧格式 int 数据 """ nodes = config.get("nodes", []) - + for node in nodes: node_type = node.get("type", "") - + # 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite) if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]: config_id = node.get("config", {}).get("config_id") - + if config_id: try: # 处理字符串、UUID 和 int(旧数据兼容)三种情况 @@ -937,6 +979,6 @@ class MemoryConfigService: f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, " f"node_type={node_type}, error={str(e)}" ) - + logger.debug("工作流配置中未找到记忆节点") return None, False diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 8a7c86e2..d6c1de87 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -12,11 +12,12 @@ from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig -from app.models import FileMetadata +from app.models import FileMetadata, ModelApiKey, ModelType from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.prompt_optimizer_model import RoleType from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository -from app.schemas import FileType +from app.schemas import FileType, FileInput +from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_perceptual_schema import ( PerceptualQuerySchema, PerceptualTimelineResponse, @@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import ( AudioModal, Content, VideoModal, TextModal ) from app.schemas.model_schema import ModelInfo +from app.services.model_service import ModelApiKeyService +from app.services.multimodal_service import MultimodalService business_logger = get_business_logger() @@ -195,21 +198,58 @@ class MemoryPerceptualService: business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) + def _get_mutlimodal_client( + self, + file_type: FileType, + config: MemoryConfig + ) -> tuple[RedBearLLM | None, ModelApiKey | None]: + model_config = None + if file_type == FileType.AUDIO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.audio_model_id + ) + elif file_type == FileType.VIDEO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.video_model_id + ) + elif file_type == FileType.DOCUMENT: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.llm_model_id + ) + elif file_type == FileType.IMAGE: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.vision_model_id + ) + llm = None + if model_config: + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_config.model_name, + provider=model_config.provider, + api_key=model_config.api_key, + base_url=model_config.api_base, + is_omni=model_config.is_omni + ) + ) + return llm, model_config + async def generate_perceptual_memory( self, end_user_id: str, - model_config: ModelInfo, - file_type: str, - file_url: str, - file_message: dict, + memory_config: MemoryConfig, + file: FileInput ): - memories = self.repository.get_by_url(file_url) + memories = self.repository.get_by_url(file.url) if memories: - business_logger.info(f"Perceptual memory already exists: {file_url}") + business_logger.info(f"Perceptual memory already exists: {file.url}") if end_user_id not in [memory.end_user_id for memory in memories]: business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") memory_cache = memories[0] - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), perceptual_type=PerceptualType(memory_cache.perceptual_type), file_path=memory_cache.file_path, @@ -219,20 +259,31 @@ class MemoryPerceptualService: meta_data=memory_cache.meta_data ) self.db.commit() - - return - llm = RedBearLLM(RedBearModelConfig( + return memory + else: + for memory in memories: + if memory.end_user_id == uuid.UUID(end_user_id): + return memory + llm, model_config = self._get_mutlimodal_client(file.type, memory_config) + multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, provider=model_config.provider, api_key=model_config.api_key, - base_url=model_config.api_base, - is_omni=model_config.is_omni - ), type=model_config.model_type) + api_base=model_config.api_base, + is_omni=model_config.is_omni, + capability=model_config.capability, + model_type=ModelType.LLM + )) + file_message = await multimodel_service.process_files( + files=[file] + ) + if file_message: + file_message = file_message[0] try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() - rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') + rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) messages = [ @@ -242,8 +293,22 @@ class MemoryPerceptualService: ]} ] result = await llm.ainvoke(messages) - content = json_repair.repair_json(result.content, return_objects=True) - path = urlparse(file_url).path + content = result.content + final_output = "" + if isinstance(content, list): + for msg in content: + if isinstance(msg, dict): + final_output += msg.get("text", "") + elif isinstance(msg, str): + final_output += msg + elif isinstance(content, dict): + final_output += content.get("text", "") + elif isinstance(content, str): + final_output = content + else: + raise ValueError(f"Unexcept Model Output Type: {result.content}") + content = json_repair.repair_json(final_output, return_objects=True) + path = urlparse(file.url).path filename = os.path.basename(path) filename = unquote(filename) file_ext = os.path.splitext(filename)[1] @@ -260,13 +325,13 @@ class MemoryPerceptualService: except ValueError: business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: - if file_type == FileType.AUDIO: + if file.type == FileType.AUDIO: file_ext = ".mp3" - elif file_type == FileType.VIDEO: + elif file.type == FileType.VIDEO: file_ext = ".mp4" - elif file_type == FileType.DOCUMENT: + elif file.type == FileType.DOCUMENT: file_ext = ".txt" - elif file_type == FileType.IMAGE: + elif file.type == FileType.IMAGE: file_ext = ".jpg" filename += file_ext file_content = { @@ -274,11 +339,11 @@ class MemoryPerceptualService: "topic": content.get("topic"), "domain": content.get("domain") } - if file_type in [FileType.IMAGE, FileType.VIDEO]: + if file.type in [FileType.IMAGE, FileType.VIDEO]: file_modalities = { "scene": content.get("scene", []) } - elif file_type in [FileType.DOCUMENT]: + elif file.type in [FileType.DOCUMENT]: file_modalities = { "section_count": content.get("section_count", 0), "title": content.get("title", ""), @@ -288,10 +353,10 @@ class MemoryPerceptualService: file_modalities = { "speaker_count": content.get("speaker_count", 0) } - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType.trans_from_file_type(file_type), - file_path=file_url, + perceptual_type=PerceptualType.trans_from_file_type(file.type), + file_path=file.url, file_name=filename, file_ext=file_ext, summary=content.get('summary', ""), @@ -301,3 +366,4 @@ class MemoryPerceptualService: } ) self.db.commit() + return memory diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index f0c7cee2..eb8df242 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -9,14 +9,12 @@ - OpenAI: 支持 URL 和 base64 格式 """ import base64 +import csv import io -import uuid +import json from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional -import csv -import json - import PyPDF2 import httpx import magic @@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.schemas.model_schema import ModelInfo from app.services.audio_transcription_service import AudioTranscriptionService -from app.tasks import write_perceptual_memory logger = get_business_logger() @@ -342,15 +339,12 @@ class MultimodalService: async def process_files( self, - end_user_id: uuid.UUID | str, files: Optional[List[FileInput]], - ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 Args: - end_user_id: 用户ID files: 文件输入列表 Returns: @@ -358,8 +352,6 @@ class MultimodalService: """ if not files: return [] - if isinstance(end_user_id, uuid.UUID): - end_user_id = str(end_user_id) # 获取对应的策略 # dashscope 的 omni 模型使用 OpenAI 兼容格式 @@ -380,23 +372,15 @@ class MultimodalService: if file.type == FileType.IMAGE and "vision" in self.capability: is_support, content = await self._process_image(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.DOCUMENT: is_support, content = await self._process_document(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.AUDIO and "audio" in self.capability: is_support, content = await self._process_audio(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.VIDEO and "video" in self.capability: is_support, content = await self._process_video(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) else: logger.warning(f"不支持的文件类型: {file.type}") except Exception as e: @@ -418,17 +402,6 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - def write_perceptual_memory( - self, - end_user_id: str, - file_type: str, - file_url: str, - file_message: dict - ): - """写入感知记忆""" - if end_user_id and self.api_config: - write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message) - async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理图片文件 diff --git a/api/app/tasks.py b/api/app/tasks.py index c37e564e..8afb2194 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1080,12 +1080,14 @@ def write_message_task( config_id: str | int, storage_type: str, user_rag_memory_id: str, + file_messages: list[dict] | None, language: str = "zh" ) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write + file_messages: Files to write config_id: Configuration ID (can be UUID string, integer, or config_id_old) storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID @@ -1097,6 +1099,8 @@ def write_message_task( Raises: Exception on failure """ + if file_messages is None: + file_messages = [] logger.info( f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " @@ -1142,7 +1146,7 @@ def write_message_task( f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, + result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type, user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result From 6bba574ca64ff23570826b34e61246e6334918c0 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 13:54:15 +0800 Subject: [PATCH 005/120] feat(memory, model): update multi-modal memory write and model list API - Adjust multi-modal memory write behavior for text and visual data - Mask API keys in model list response to prevent exposure - Add capability-based filtering to the model list API --- api/app/controllers/model_controller.py | 12 ++ .../controllers/user_memory_controllers.py | 108 ++++++++------ .../core/memory/agent/utils/get_dialogs.py | 5 +- .../core/memory/agent/utils/write_tools.py | 40 +---- .../core/memory/llm_tools/chunker_client.py | 10 +- api/app/core/memory/models/graph_models.py | 33 +++-- api/app/core/memory/models/message_models.py | 4 +- .../extraction_orchestrator.py | 80 +++++++++- api/app/core/workflow/nodes/memory/node.py | 9 +- api/app/models/models_model.py | 5 +- .../repositories/memory_config_repository.py | 3 + api/app/repositories/model_repository.py | 12 +- api/app/repositories/neo4j/add_nodes.py | 139 +++--------------- api/app/repositories/neo4j/cypher_queries.py | 109 +++++++------- api/app/repositories/neo4j/graph_saver.py | 72 ++++++--- api/app/schemas/model_schema.py | 38 +++-- api/app/services/memory_agent_service.py | 39 ++--- api/app/services/memory_perceptual_service.py | 8 +- api/app/services/pilot_run_service.py | 3 + api/app/services/user_memory_service.py | 3 +- api/app/tasks.py | 58 +------- 21 files changed, 389 insertions(+), 401 deletions(-) diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 6204a745..71fd41ad 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -42,6 +42,7 @@ def get_model_strategies(): @router.get("", response_model=ApiResponse) def get_model_list( type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"), @@ -74,10 +75,21 @@ def get_model_list( unique_flat_type = list(dict.fromkeys(flat_type)) type_list = [ModelType(t.lower()) for t in unique_flat_type] + capability_list = [] + if capability is not None: + flat_capability = [] + for item in capability: + split_items = [c.strip() for c in item.split(', ') if c.strip()] + flat_capability.extend(split_items) + + unique_flat_capability = list(dict.fromkeys(flat_capability)) + capability_list = unique_flat_capability + api_logger.error(f"获取模型type_list: {type_list}") query = model_schema.ModelConfigQuery( type=type_list, provider=provider, + capability=capability_list, is_active=is_active, is_public=is_public, search=search, diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index be796ff9..3ce1df6e 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -5,7 +5,7 @@ from typing import Optional import datetime from sqlalchemy.orm import Session -from fastapi import APIRouter, Depends,Header +from fastapi import APIRouter, Depends, Header from app.db import get_db from app.core.language_utils import get_language_from_header @@ -19,7 +19,7 @@ from app.services.user_memory_service import ( analytics_graph_data, analytics_community_graph_data, ) -from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction +from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository @@ -45,9 +45,9 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的记忆洞察报告 @@ -73,10 +73,10 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( - end_user_id: str, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的用户摘要 @@ -90,7 +90,7 @@ async def get_user_summary_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -102,7 +102,7 @@ async def get_user_summary_api( api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language) + result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language) if result["is_cached"]: api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") @@ -117,10 +117,10 @@ async def get_user_summary_api( @router.post("/analytics/generate_cache", response_model=ApiResponse) async def generate_cache_api( - request: GenerateCacheRequest, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + request: GenerateCacheRequest, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 手动触发缓存生成 @@ -134,7 +134,7 @@ async def generate_cache_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -155,10 +155,12 @@ async def generate_cache_api( api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") # 生成记忆洞察 - insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language) + insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, + language=language) # 生成用户摘要 - summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language) + summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, + language=language) # 构建响应 result = { @@ -209,9 +211,9 @@ async def generate_cache_api( @router.get("/analytics/node_statistics", response_model=ApiResponse) async def get_node_statistics_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -220,7 +222,8 @@ async def get_node_statistics_api( api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") + api_logger.info( + f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") try: # 调用新的记忆类型统计函数 @@ -228,21 +231,23 @@ async def get_node_statistics_api( # 计算总数用于日志 total_count = sum(item["count"] for item in result) - api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") + api_logger.info( + f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) + @router.get("/analytics/graph_data", response_model=ApiResponse) async def get_graph_data_api( - end_user_id: str, - node_types: Optional[str] = None, - limit: int = 100, - depth: int = 1, - center_node_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + node_types: Optional[str] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -298,9 +303,9 @@ async def get_graph_data_api( @router.get("/analytics/community_graph", response_model=ApiResponse) async def get_community_graph_data_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -334,9 +339,9 @@ async def get_community_graph_data_api( @router.get("/read_end_user/profile", response_model=ApiResponse) async def get_end_user_profile( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) @@ -385,9 +390,9 @@ async def get_end_user_profile( @router.post("/updated_end_user/profile", response_model=ApiResponse) async def update_end_user_profile( - profile_update: EndUserProfileUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + profile_update: EndUserProfileUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 更新终端用户的基本信息 @@ -417,7 +422,7 @@ async def update_end_user_profile( else: error_msg = result["error"] api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") - + # 根据错误类型映射到合适的业务错误码 if error_msg == "终端用户不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) @@ -427,15 +432,18 @@ async def update_end_user_profile( # 只有未预期的错误才使用 INTERNAL_ERROR return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) + @router.get("/memory_space/timeline_memories", response_model=ApiResponse) -async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): +async def memory_space_timeline_of_shared_memories( + id: str, label: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): # 使用集中化的语言校验 language = get_language_from_header(language_type) - - workspace_id=current_user.current_workspace_id + + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_ timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language) return success(data=timeline_memories_result, msg="共同记忆时间线") + + @router.get("/memory_space/relationship_evolution", response_model=ApiResponse) async def memory_space_relationship_evolution(id: str, label: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ): try: api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 3b06defe..4c667061 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -11,7 +11,7 @@ async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", end_user_id: str = "group_1", messages: list = None, - ref_id: str = "wyl_20251027", + ref_id: str = "", config_id: str = None ) -> List[DialogData]: """Generate chunks from structured messages using the specified chunker strategy. @@ -40,12 +40,13 @@ async def get_chunked_dialogs( role = msg['role'] content = msg['content'] + files = msg.get("file_content", []) if role not in ['user', 'assistant']: raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") if content.strip(): - conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) + conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files)) if not conversation_messages: raise ValueError("Message list cannot be empty after filtering") diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 147a0316..413f54da 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -5,8 +5,8 @@ This module provides the main write function for executing the knowledge extract pipeline. Only MemoryConfig is needed - clients are constructed internally. """ import asyncio -import uuid import time +import uuid from datetime import datetime from dotenv import load_dotenv @@ -19,10 +19,8 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context -from app.models import MemoryPerceptualModel from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges -from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \ - add_perceptual_dialogue_edges +from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -36,7 +34,6 @@ async def write( end_user_id: str, memory_config: MemoryConfig, messages: list, - file_content: list[MemoryPerceptualModel], ref_id: str = "", language: str = "zh", ) -> None: @@ -47,7 +44,6 @@ async def write( end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] - file_content: mutilmodal message list ref_id: Reference ID, defaults to "" language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ @@ -142,9 +138,11 @@ async def write( all_chunk_nodes, all_statement_nodes, all_entity_nodes, + all_perceptual_nodes, all_statement_chunk_edges, all_statement_entity_edges, all_entity_entity_edges, + all_perceptual_edges, all_dedup_details, ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) @@ -169,9 +167,11 @@ async def write( chunk_nodes=all_chunk_nodes, statement_nodes=all_statement_nodes, entity_nodes=all_entity_nodes, + perceptual_nodes=all_perceptual_nodes, statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, + perceptual_edges=all_perceptual_edges, connector=neo4j_connector, ) if success: @@ -230,34 +230,6 @@ async def write( finally: log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file) - # Step 5: Save perceptual memory to Neo4j - step_start = time.time() - if file_content: - try: - pc_connector = Neo4jConnector() - try: - created_ids = await add_perceptual_nodes( - perceptuals=file_content, - connector=pc_connector, - embedder_client=embedder_client, - ) - # 如果有 ref_id,建立感知记忆与对话的关联 - if ref_id and created_ids: - await add_perceptual_dialogue_edges( - perceptuals=file_content, - dialog_id=ref_id, - connector=pc_connector, - ) - logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j") - finally: - try: - await pc_connector.close() - except Exception: - pass - except Exception as e: - logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True) - log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file) - # Log total pipeline time total_time = time.time() - pipeline_start log_time("TOTAL PIPELINE TIME", total_time, log_file) diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 93a2df82..51d15aab 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -1,10 +1,10 @@ -from typing import Any, List -import re -import os import asyncio import json -import numpy as np import logging +import os +from typing import Any, List + +import numpy as np # Fix tokenizer parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -246,6 +246,7 @@ class ChunkerClient: "total_sub_chunks": len(sub_chunks), "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) else: @@ -258,6 +259,7 @@ class ChunkerClient: "message_role": msg.role, "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index fb251f1f..1b8c9d52 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -114,7 +114,7 @@ class Edge(BaseModel): end_user_id: str = Field(..., description="The end user ID of the edge.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") - expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") + expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.") class ChunkEdge(Edge): @@ -175,6 +175,12 @@ class EntityEntityEdge(Edge): return parse_historical_datetime(v) +class PerceptualEdge(Edge): + """Edge connecting perceptual nodes to their source chunks + """ + pass + + class Node(BaseModel): """Base class for all graph nodes in the knowledge graph. @@ -555,19 +561,16 @@ class MemorySummaryNode(Node): ) -class MutlimodalNode(Node): +class PerceptualNode(Node): """Node representing a multimodal message in the knowledge graph. - - Attributes: - dialog_id: ID of the parent dialog - message_id: ID of the message - metadata: Additional message metadata - embedding: Optional embedding vector for the message """ - dialog_id: str = Field(..., description="ID of the parent dialog") - message_id: str = Field(..., description="ID of the message") - summary: str = Field(..., description="The text content of the message") - file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')") - file_path: List[str] = Field(..., description="List of file paths for multimodal content") - metadata: dict = Field(default_factory=dict, description="Additional message metadata") - embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message") + perceptual_type: int + file_path: str + file_name: str + file_ext: str + summary: str + keywords: list[str] + topic: str + domain: str + file_type: str + summary_embedding: list[float] | None diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 2f8660af..66203067 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -30,6 +30,7 @@ class ConversationMessage(BaseModel): """ role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") msg: str = Field(..., description="The text content of the message.") + files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True) class TemporalValidityRange(BaseModel): @@ -130,7 +131,8 @@ class Chunk(BaseModel): content: str = Field(..., description="The content of the chunk as a string.") speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") - chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") + files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.") + chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") @classmethod diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 6e94a84f..da10c497 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -31,7 +31,9 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, StatementChunkEdge, StatementEntityEdge, - StatementNode + StatementNode, + PerceptualEdge, + PerceptualNode ) from app.core.memory.models.message_models import DialogData from app.core.memory.models.ontology_extraction_models import OntologyTypeList @@ -170,9 +172,11 @@ class ExtractionOrchestrator: list[ChunkNode], list[StatementNode], list[ExtractedEntityNode], + list[PerceptualNode], list[StatementChunkEdge], list[StatementEntityEdge], list[EntityEntityEdge], + list[PerceptualEdge], dict ]: """ @@ -259,9 +263,11 @@ class ExtractionOrchestrator: chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) = await self._create_nodes_and_edges(dialog_data_list) # 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总) @@ -275,7 +281,16 @@ class ExtractionOrchestrator: # 注意:deduplication 消息已在创建节点和边完成后立即发送 - result = await self._run_dedup_and_write_summary( + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + dialog_data_list, + ) = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, statement_nodes, @@ -287,7 +302,18 @@ class ExtractionOrchestrator: ) logger.info(f"知识提取流水线运行完成({mode_str})") - return result + return ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + perceptual_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + perceptual_edges, + dialog_data_list, + ) except Exception as e: logger.error(f"知识提取流水线运行失败: {e}", exc_info=True) @@ -1000,9 +1026,11 @@ class ExtractionOrchestrator: List[ChunkNode], List[StatementNode], List[ExtractedEntityNode], + List[PerceptualNode], List[StatementChunkEdge], List[StatementEntityEdge], - List[EntityEntityEdge] + List[EntityEntityEdge], + List[PerceptualEdge] ]: """ 创建图数据库节点和边 @@ -1026,6 +1054,8 @@ class ExtractionOrchestrator: statement_chunk_edges = [] statement_entity_edges = [] entity_entity_edges = [] + perceptual_nodes = [] + perceptual_edges = [] # 用于去重的集合 entity_id_set = set() @@ -1069,6 +1099,46 @@ class ExtractionOrchestrator: metadata=chunk.metadata, ) chunk_nodes.append(chunk_node) + logger.error(f"chunk file: {chunk.files}") + + for p, file_type in chunk.files: + + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + # 生成 summary embedding(如果有 embedder_client) + summary_embedding = None + if self.embedder_client and p.summary: + try: + summary_embedding = (await self.embedder_client.response([p.summary]))[0] + except Exception as emb_err: + print(f"Failed to embed perceptual summary: {emb_err}") + + perceptual = PerceptualNode( + name=f"Perceptual_{p.id}", + **{ + "id": str(p.id), + "end_user_id": str(p.end_user_id), + "perceptual_type": p.perceptual_type, + "file_path": p.file_path or "", + "file_name": p.file_name or "", + "file_ext": p.file_ext or "", + "summary": p.summary or "", + "keywords": content_meta.get("keywords", []), + "topic": content_meta.get("topic", ""), + "domain": content_meta.get("domain", ""), + "created_at": p.created_time.isoformat() if p.created_time else None, + "file_type": file_type, + "summary_embedding": summary_embedding, + }) + perceptual_nodes.append(perceptual) + perceptual_edges.append(PerceptualEdge( + source=perceptual.id, + target=chunk.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + )) # 处理每个陈述句 for statement in chunk.statements: @@ -1248,9 +1318,11 @@ class ExtractionOrchestrator: chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) async def _run_dedup_and_write_summary( diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index cbdad0fa..a28247e4 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -72,7 +72,6 @@ class MemoryWriteNode(BaseNode): if not end_user_id: raise RuntimeError("End user id is required") messages = [] - multimodal_memories = [] if self.typed_config.message: messages.append({ "role": "user", @@ -104,19 +103,15 @@ class MemoryWriteNode(BaseNode): url=file_instence.value.url, file_type=file_instence.value.origin_file_type ).model_dump()) - multimodal_memories.append({ - "role": message.role, - "files": file_info - }) messages.append({ "role": message.role, - "content": self._render_template(content, variable_pool) + "content": self._render_template(content, variable_pool), + "files": file_info }) write_message_task.delay( end_user_id=end_user_id, message=messages, - file_messages=multimodal_memories, config_id=str(self.typed_config.config_id), storage_type=state["memory_storage_type"], user_rag_memory_id=state["user_rag_memory_id"] diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 23fafcef..44a844d0 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -2,10 +2,11 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text -from sqlalchemy.dialects.postgresql import UUID, JSON +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text +from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY from sqlalchemy.orm import relationship from sqlalchemy.sql import func + from app.db import Base diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 6fb41914..e64d19a3 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -408,6 +408,9 @@ class MemoryConfigRepository: "llm_id": db_config.llm_id, "embedding_id": db_config.embedding_id, "rerank_id": db_config.rerank_id, + "vision_id": db_config.vision_id, + "audio_id": db_config.audio_id, + "video_id": db_config.video_id, "enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise, "enable_llm_disambiguation": db_config.enable_llm_disambiguation, "deep_retrieval": db_config.deep_retrieval, diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index f49227d3..fd95c793 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -1,14 +1,15 @@ -from sqlalchemy.orm import Session, joinedload, selectinload -from sqlalchemy import and_, or_, func, desc, select -from typing import List, Optional, Dict, Any, Tuple import uuid +from typing import List, Optional, Dict, Any, Tuple +from sqlalchemy import and_, or_, func, desc +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association from app.schemas.model_schema import ( ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigQuery, ModelConfigQueryNew ) -from app.core.logging_config import get_db_logger # 获取数据库专用日志器 db_logger = get_db_logger() @@ -137,6 +138,9 @@ class ModelConfigRepository: type_values.append(ModelType.LLM) filters.append(ModelConfig.type.in_(type_values)) + if query.capability: + filters.append(ModelConfig.capability.contains(query.capability)) + if query.is_active is not None: filters.append(ModelConfig.is_active == query.is_active) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 3a017089..a53ca289 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,16 +1,19 @@ from typing import List, Optional +from app.core.logging_config import get_logger from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ - MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE + MEMORY_SUMMARY_NODE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector +logger = get_logger(__name__) + async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") - print(f"All end_user_id: {end_user_id} node and edge deleted successfully") + logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result @@ -25,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn List of created node UUIDs or None if failed """ if not dialogues: - print("No dialogues to save") + logger.info("No dialogues to save") return [] try: @@ -50,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") + logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") return created_uuids except Exception as e: - print(f"Error creating dialogue nodes: {e}") + logger.info(f"Error creating dialogue nodes: {e}") return None @@ -69,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC List of created node UUIDs or None if failed """ if not statements: - print("No statements to save") + logger.info("No statements to save") return [] try: @@ -122,11 +125,11 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} statement nodes") + logger.info(f"Successfully created {len(created_uuids)} statement nodes") return created_uuids except Exception as e: - print(f"Error creating statement nodes: {e}") + logger.info(f"Error creating statement nodes: {e}") return None @@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> List of created chunk UUIDs or None if failed """ if not chunks: - print("No chunk nodes to add") + logger.info("No chunk nodes to add") return [] try: @@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} chunk nodes") + logger.info(f"Successfully created {len(created_uuids)} chunk nodes") return created_uuids except Exception as e: - print(f"Error creating chunk nodes: {e}") + logger.info(f"Error creating chunk nodes: {e}") return None -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[ - List[str]]: +async def add_memory_summary_nodes( + summaries: List[MemorySummaryNode], + connector: Neo4jConnector +) -> Optional[List[str]]: """Add memory summary nodes to Neo4j in batch. Args: @@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector List of created summary node ids or None if failed """ if not summaries: - print("No memory summary nodes to add") + logger.info("No memory summary nodes to add") return [] try: @@ -220,110 +225,8 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") + logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids except Exception as e: - print(f"Failed to save MemorySummary nodes to Neo4j: {e}") - return None - - -async def add_perceptual_nodes( - perceptuals: list, - connector: Neo4jConnector, - embedder_client=None, -) -> Optional[List[str]]: - """Add perceptual memory nodes to Neo4j in batch. - - Args: - perceptuals: List of MemoryPerceptualModel objects from PostgreSQL - connector: Neo4j connector instance - embedder_client: Optional embedder client for generating summary embeddings - - Returns: - List of created node UUIDs or None if failed - """ - if not perceptuals: - print("No perceptual nodes to add") - return [] - - try: - flattened = [] - for p in perceptuals: - meta = p.meta_data or {} - content_meta = meta.get("content", {}) - - # 生成 summary embedding(如果有 embedder_client) - summary_embedding = None - if embedder_client and p.summary: - try: - summary_embedding = (await embedder_client.response([p.summary]))[0] - except Exception as emb_err: - print(f"Failed to embed perceptual summary: {emb_err}") - - flattened.append({ - "id": str(p.id), - "end_user_id": str(p.end_user_id), - "perceptual_type": p.perceptual_type, - "file_path": p.file_path or "", - "file_name": p.file_name or "", - "file_ext": p.file_ext or "", - "summary": p.summary or "", - "keywords": content_meta.get("keywords", []), - "topic": content_meta.get("topic", ""), - "domain": content_meta.get("domain", ""), - "created_at": p.created_time.isoformat() if p.created_time else None, - "summary_embedding": summary_embedding, - }) - - result = await connector.execute_query( - PERCEPTUAL_NODE_SAVE, - perceptuals=flattened, - ) - created_uuids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j") - return created_uuids - - except Exception as e: - print(f"Failed to save Perceptual nodes to Neo4j: {e}") - return None - - -async def add_perceptual_dialogue_edges( - perceptuals: list, - dialog_id: str, - connector: Neo4jConnector, -) -> Optional[List[str]]: - """Add edges between Perceptual nodes and Dialogue nodes. - - Args: - perceptuals: List of MemoryPerceptualModel objects - dialog_id: The dialogue ID (or ref_id) to link to - connector: Neo4j connector instance - - Returns: - List of created edge element IDs or None if failed - """ - if not perceptuals or not dialog_id: - return [] - - try: - edges = [] - for p in perceptuals: - edges.append({ - "perceptual_id": str(p.id), - "dialog_id": dialog_id, - "end_user_id": str(p.end_user_id), - "created_at": p.created_time.isoformat() if p.created_time else None, - }) - - result = await connector.execute_query( - PERCEPTUAL_DIALOGUE_EDGE_SAVE, - edges=edges, - ) - created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j") - return created_ids - - except Exception as e: - print(f"Failed to save Perceptual-Dialogue edges: {e}") + logger.info(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 49dbe2a5..d70a30e9 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1003,60 +1003,69 @@ RETURN DISTINCT """ Graph_Node_query = """ - MATCH (n:MemorySummary) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 0 AS priority - LIMIT $limit +MATCH (n:MemorySummary) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Dialogue) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT 1 +MATCH (n:Dialogue) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT 1 - UNION ALL +UNION ALL - MATCH (n:Statement) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT $limit +MATCH (n:Statement) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:ExtractedEntity) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 2 AS priority - LIMIT $limit +MATCH (n:ExtractedEntity) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 2 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Chunk) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 3 AS priority - LIMIT $limit +MATCH (n:Chunk) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 3 AS priority +LIMIT $limit - """ +UNION ALL +MATCH (n:Perceptual) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 4 AS priority + +""" # ============================================================ # Community 节点 & BELONGS_TO_COMMUNITY 边 @@ -1340,19 +1349,19 @@ SET n += { topic: p.topic, domain: p.domain, created_at: p.created_at, + file_type: p.file_type, summary_embedding: p.summary_embedding } RETURN n.id AS uuid """ # 感知记忆与对话的关联边 -PERCEPTUAL_DIALOGUE_EDGE_SAVE = """ +PERCEPTUAL_CHUNK_EDGE_SAVE = """ UNWIND $edges AS edge MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id}) -MATCH (d:Dialogue {end_user_id: edge.end_user_id}) -WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id -MERGE (d)-[r:HAS_PERCEPTUAL]->(p) -SET r.end_user_id = edge.end_user_id, +MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id}) +MERGE (c)-[r:HAS_PERCEPTUAL]->(p) +ON CREATE SET r.end_user_id = edge.end_user_id, r.created_at = edge.created_at RETURN elementId(r) AS uuid """ diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 34497d5b..d78dcef6 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import ( StatementNode, ExtractedEntityNode, EntityEntityEdge, + PerceptualNode, + PerceptualEdge, ) import logging + logger = logging.getLogger(__name__) + + async def save_entities_and_relationships( - entity_nodes: List[ExtractedEntityNode], - entity_entity_edges: List[EntityEntityEdge], - connector: Neo4jConnector + entity_nodes: List[ExtractedEntityNode], + entity_entity_edges: List[EntityEntityEdge], + connector: Neo4jConnector ): """Save entities and their relationships using graph models""" all_entities = [entity.model_dump() for entity in entity_nodes] @@ -73,8 +78,8 @@ async def save_entities_and_relationships( async def save_chunk_nodes( - chunk_nodes: List[ChunkNode], - connector: Neo4jConnector + chunk_nodes: List[ChunkNode], + connector: Neo4jConnector ): """Save chunk nodes using graph models""" if not chunk_nodes: @@ -89,8 +94,8 @@ async def save_chunk_nodes( async def save_statement_chunk_edges( - statement_chunk_edges: List[StatementChunkEdge], - connector: Neo4jConnector + statement_chunk_edges: List[StatementChunkEdge], + connector: Neo4jConnector ): """Save statement-chunk edges using graph models""" if not statement_chunk_edges: @@ -118,8 +123,8 @@ async def save_statement_chunk_edges( async def save_statement_entity_edges( - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + statement_entity_edges: List[StatementEntityEdge], + connector: Neo4jConnector ): """Save statement-entity edges using graph models""" if not statement_entity_edges: @@ -142,7 +147,7 @@ async def save_statement_entity_edges( if all_se_edges: try: await connector.execute_query( - STATEMENT_ENTITY_EDGE_SAVE, + STATEMENT_ENTITY_EDGE_SAVE, relationships=all_se_edges ) except Exception: @@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j( chunk_nodes: List[ChunkNode], statement_nodes: List[StatementNode], entity_nodes: List[ExtractedEntityNode], + perceptual_nodes: List[PerceptualNode], entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], + perceptual_edges: List[PerceptualEdge], connector: Neo4jConnector, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j( chunk_nodes: List of ChunkNode objects to save statement_nodes: List of StatementNode objects to save entity_nodes: List of ExtractedEntityNode objects to save + perceptual_nodes: List of PerceptualNode objects to save entity_edges: List of EntityEntityEdge objects to save statement_chunk_edges: List of StatementChunkEdge objects to save statement_entity_edges: List of StatementEntityEdge objects to save + perceptual_edges: List of PerceptualEdge objects to save connector: Neo4j connector instance Returns: @@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) dialogue_uuids = [record["uuid"] async for record in result] results['dialogues'] = dialogue_uuids - print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") + logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") # 2. Save all chunk nodes in batch if chunk_nodes: @@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j( results['chunks'] = chunk_uuids logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") + if perceptual_nodes: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE + perceptual_data = [node.model_dump() for node in perceptual_nodes] + result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data) + perceptual_uuids = [record["uuid"] async for record in result] + results["perceptuals"] = perceptual_uuids + logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j") + # 3. Save all statement nodes in batch if statement_nodes: from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE @@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j( results['statement_entity_edges'] = se_uuids logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + if perceptual_edges: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE + perceptual_edge_data = [] + for edge in perceptual_edges: + print(edge.source, edge.target) + perceptual_edge_data.append({ + "perceptual_id": edge.source, + "chunk_id": edge.target, + "end_user_id": edge.end_user_id, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + }) + result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data) + perceptual_edges_uuids = [record["uuid"] async for record in result] + results['perceptual_chunk_edges'] = perceptual_edges_uuids + logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j") + return results try: @@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j( def schedule_clustering_after_write( - entity_nodes: List, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, + entity_nodes: List, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 写入 Neo4j 成功后,调度后台聚类任务。 @@ -325,14 +358,15 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id)) async def _trigger_clustering( - new_entity_ids: List[str], - end_user_id: str, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, + new_entity_ids: List[str], + end_user_id: str, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 058f082d..668a84a8 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase): updated_at: datetime.datetime api_keys: List["ModelApiKey"] = [] + @staticmethod + def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str: + if not key or len(key) <= prefix + suffix: + return "*" * len(key) + return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:] + @field_validator("api_keys", mode="after") @classmethod def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]: @@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase): def _serialize_created_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None + @field_serializer("api_keys", when_used="json") + def _serialize_api_keys(self, api_keys: List["ModelApiKey"]): + result = [] + for api_key in api_keys: + data = api_key.model_dump() + data["api_key"] = self.mask_api_key(api_key.api_key) + result.append(data) + return result + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase): if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict): self.model_config_ids = [ mc.id for mc in self.model_configs - if hasattr(mc, 'id') - and not getattr(mc, 'is_composite', False) - and getattr(mc, 'name', None) == self.model_name + if hasattr(mc, 'id') + and not getattr(mc, 'is_composite', False) + and getattr(mc, 'name', None) == self.model_name ] # 情况2:字典列表 elif isinstance(self.model_configs, list): self.model_config_ids = [ mc['id'] if isinstance(mc, dict) else mc.id for mc in self.model_configs - if ((isinstance(mc, dict) - and 'id' in mc + if ((isinstance(mc, dict) + and 'id' in mc and not mc.get('is_composite', False) - and mc.get('name') == self.model_name) or - (hasattr(mc, 'id') + and mc.get('name') == self.model_name) or + (hasattr(mc, 'id') and not getattr(mc, 'is_composite', False) and getattr(mc, 'name', None) == self.model_name)) ] @@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase): validate_assignment=True # 确保赋值触发校验 ) - @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel): """模型配置查询Schema""" type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") + capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)") is_active: Optional[bool] = Field(None, description="激活状态筛选") is_public: Optional[bool] = Field(None, description="公开状态筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) @@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel): is_composite: Optional[bool] = Field(None, description="组合模型筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelMarketplace(BaseModel): """模型广场响应Schema""" llm_models: List[ModelConfig] = [] @@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel): class ModelBase(BaseModel): """基础模型Schema""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID name: str type: str @@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel): is_deprecated: Optional[bool] = Field(None, description="是否弃用") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelInfo(BaseModel): """模型信息Schema""" model_name: str = Field(..., description="模型名称") @@ -336,4 +353,3 @@ class ModelInfo(BaseModel): is_omni: bool = Field(default=False, description="是否为omni模型") model_type: ModelType = Field(..., description="模型类型") capability: List[str] = Field(default_factory=list, description="模型能力列表") - diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 875f02bb..8bb6538d 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -274,7 +274,6 @@ class MemoryAgentService: self, end_user_id: str, messages: list[dict], - file_messages: list[dict], config_id: Optional[uuid.UUID] | int, db: Session, storage_type: str, @@ -287,7 +286,6 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) messages: Message to write - files: Files to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -348,15 +346,15 @@ class MemoryAgentService: raise ValueError(error_msg) perceptual_serivce = MemoryPerceptualService(db) - file_content = [] - for message in file_messages: + for message in messages: + message["file_content"] = [] for file in message["files"]: file_object = await perceptual_serivce.generate_perceptual_memory( end_user_id=end_user_id, memory_config=memory_config, file=FileInput(**file) ) - file_content.append(file_object) + message["file_content"].append((file_object, file["type"])) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: @@ -368,7 +366,6 @@ class MemoryAgentService: await write_neo4j( end_user_id=end_user_id, messages=messages, - file_content=file_content, memory_config=memory_config, ref_id='', language=language @@ -380,19 +377,23 @@ class MemoryAgentService: if deleted: logger.info( f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") - return self.writer_messages_deal( - "success", - start_time, - end_user_id, - config_id, - message_text, - { - "status": "success", - "data": messages, - "config_id": memory_config.config_id, - "config_name": memory_config.config_name - } - ) + for message in messages: + message["file_content"] = [ + perceptual[0].file_path for perceptual in message["file_content"] + ] + return self.writer_messages_deal( + "success", + start_time, + end_user_id, + config_id, + message_text, + { + "status": "success", + "data": messages, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name + } + ) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index d6c1de87..8255dbbe 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -317,11 +317,11 @@ class MemoryPerceptualService: stmt = select(FileMetadata).where( FileMetadata.id == file_id ) - file = self.db.execute(stmt).scalar_one_or_none() + file_obj = self.db.execute(stmt).scalar_one_or_none() - if file: - filename = file.file_name - file_ext = file.file_ext + if file_obj: + filename = file_obj.file_name + file_ext = file_obj.file_ext except ValueError: business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index fc749157..4617946b 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -297,9 +297,12 @@ async def run_pilot_extraction( chunk_nodes, statement_nodes, entity_nodes, + _, statement_chunk_edges, statement_entity_edges, entity_edges, + _, + _ ) = extraction_result log_time("Extraction Pipeline", time.time() - step_start, log_file) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index d5d19e0d..29516acc 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1888,7 +1888,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ "Chunk": ["content", "created_at"], "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"], "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"], - "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 + "MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段 + "Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"] } # 获取该节点类型的白名单字段 diff --git a/api/app/tasks.py b/api/app/tasks.py index 8afb2194..f243eac3 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1080,14 +1080,12 @@ def write_message_task( config_id: str | int, storage_type: str, user_rag_memory_id: str, - file_messages: list[dict] | None, language: str = "zh" ) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write - file_messages: Files to write config_id: Configuration ID (can be UUID string, integer, or config_id_old) storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID @@ -1099,9 +1097,6 @@ def write_message_task( Raises: Exception on failure """ - if file_messages is None: - file_messages = [] - logger.info( f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " f"config_id={config_id} (type: {type(config_id).__name__}), " @@ -1146,7 +1141,7 @@ def write_message_task( f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type, + result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result @@ -2617,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ } -@celery_app.task( - name="app.tasks.write_perceptual_memory", - bind=True, - ignore_result=True, - max_retries=0, - acks_late=False, - time_limit=3600, - soft_time_limit=3300, -) -def write_perceptual_memory( - self, - end_user_id: str, - model_api_config: dict, - file_type: str, - file_url: str, - file_message: dict -): - """ - Write perceptual memory for a user into PostgreSQL and Neo4j. - - This task generates or updates the user's perceptual memory - in the backend databases. It is intended to be executed asynchronously - via Celery. - - Args: - end_user_id (uuid.UUID): The unique identifier of the end user. - model_api_config (ModelInfo): API configuration for the model - used to generate perceptual memory. - file_type (str): The file type - file_url (url): The url of file - file_message (dict): The file message containing details about the file - to be processed. - - Returns: - None - """ - file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest() - set_asyncio_event_loop() - with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()): - model_info = ModelInfo(**model_api_config) - with get_db_context() as db: - memory_perceptual_service = MemoryPerceptualService(db) - return asyncio.run(memory_perceptual_service.generate_perceptual_memory( - end_user_id, - model_info, - file_type, - file_url, - file_message, - )) - - # ============================================================================= # 社区聚类补全任务(触发型) # ============================================================================= From de6e2f54d2d773969c36a3658791668171034fc4 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 14:30:07 +0800 Subject: [PATCH 006/120] fix(perceptual): prevent errors when writing unsupported modalities --- api/app/services/memory_agent_service.py | 2 ++ api/app/services/memory_api_service.py | 4 ---- api/app/services/memory_perceptual_service.py | 6 ++++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 8bb6538d..3cfcc1d6 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -354,6 +354,8 @@ class MemoryAgentService: memory_config=memory_config, file=FileInput(**file) ) + if file_object is None: + continue message["file_content"].append((file_object, file["type"])) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 9a0fb8ed..9282fc28 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -131,7 +131,6 @@ class MemoryAPIService: message: str, config_id: str, storage_type: str = "neo4j", - files: Optional[list]=None, user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Write memory with validation. @@ -154,8 +153,6 @@ class MemoryAPIService: ResourceNotFoundException: If end_user not found BusinessException: If end_user not in authorized workspace or write fails """ - if files is None: - files = list() logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") # Validate end_user exists and belongs to workspace @@ -175,7 +172,6 @@ class MemoryAPIService: db=self.db, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id or "", - files=files ) logger.info(f"Memory write successful for end_user: {end_user_id}") diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 8255dbbe..effceda7 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -277,8 +277,10 @@ class MemoryPerceptualService: file_message = await multimodel_service.process_files( files=[file] ) - if file_message: - file_message = file_message[0] + if not file_message: + logger.warning(f"Unsupport file type {file}, model capability: {model_config.capability}") + return None + file_message = file_message[0] try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: From 3dc863cabf29d08c02612253f3569a18fe6ffadc Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 15:16:16 +0800 Subject: [PATCH 007/120] feat(memory): add audio_id, vision_id and video_id fields to memory configuration --- api/app/schemas/memory_storage_schema.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index c8abbc46..711b6de9 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -264,6 +264,9 @@ class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用 class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 config_id: Union[uuid.UUID, int, str] = None llm_id: Optional[str] = Field(None, description="LLM模型配置ID") + audio_id: Optional[str] = Field(None, description="语音模型ID") + vision_id: Optional[str] = Field(None, description="视觉模型ID") + video_id: Optional[str] = Field(None, description="视频模型ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") enable_llm_dedup_blockwise: Optional[bool] = None From b739b032d940ac77a2d8b269bdc77beb956f4e14 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 15:17:01 +0800 Subject: [PATCH 008/120] fix(workflow): remove edges for unreachable nodes in graph --- api/app/core/workflow/engine/graph_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 674c45d0..c5cf3324 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -390,6 +390,8 @@ class GraphBuilder: for edge in self.edges: source = edge.get("source") target = edge.get("target") + if source not in self.reachable_nodes or target not in self.reachable_nodes: + continue condition = edge.get("condition") edge_type = edge.get("type") From 8ff1c6bd08813ff4a5e35afc457d01bfe474b8a1 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 24 Mar 2026 15:33:09 +0800 Subject: [PATCH 009/120] [add] migratinon script --- .../versions/e28bcc212da5_202603241530.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 api/migrations/versions/e28bcc212da5_202603241530.py diff --git a/api/migrations/versions/e28bcc212da5_202603241530.py b/api/migrations/versions/e28bcc212da5_202603241530.py new file mode 100644 index 00000000..00173522 --- /dev/null +++ b/api/migrations/versions/e28bcc212da5_202603241530.py @@ -0,0 +1,34 @@ +"""202603241530 + +Revision ID: e28bcc212da5 +Revises: 05a681a6ca93 +Create Date: 2026-03-24 15:32:14.461480 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e28bcc212da5' +down_revision: Union[str, None] = '05a681a6ca93' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('memory_config', sa.Column('vision_id', sa.String(), nullable=True, comment='视觉模型配置ID')) + op.add_column('memory_config', sa.Column('audio_id', sa.String(), nullable=True, comment='语音模型配置ID')) + op.add_column('memory_config', sa.Column('video_id', sa.String(), nullable=True, comment='视频模型配置ID')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('memory_config', 'video_id') + op.drop_column('memory_config', 'audio_id') + op.drop_column('memory_config', 'vision_id') + # ### end Alembic commands ### From 1c49e3c16721dc72f39e6f887b4dde2e05c992cf Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 17:12:37 +0800 Subject: [PATCH 010/120] feat(workflow): use lightweight deque for streaming scheduler output queue to reduce read/write overhead --- .../engine/stream_output_coordinator.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index 6685a49e..dcc92fdb 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -3,7 +3,7 @@ # @Email: 1533512157@qq.com # @Time : 2026/2/9 15:11 import re -from queue import Queue +from collections import deque from typing import AsyncGenerator from pydantic import BaseModel, Field, PrivateAttr @@ -256,7 +256,7 @@ class StreamOutputCoordinator: def __init__(self): self.end_outputs: dict[str, StreamOutputConfig] = {} self.activate_end: str | None = None - self.output_queue: Queue = Queue() + self.output_queue: deque[str] = deque() self.processed_outputs = [] def initialize_end_outputs( @@ -266,7 +266,7 @@ class StreamOutputCoordinator: self.end_outputs = end_node_map self.processed_outputs = [] self.activate_end = None - self.output_queue = Queue() + self.output_queue = deque() @property def current_activate_end_info(self): @@ -296,13 +296,13 @@ class StreamOutputCoordinator: scope (str): The node ID or scope that has completed execution. status (str | None): Optional status of the node (used for branch/control nodes). """ - for node in self.end_outputs.keys(): + for node in self.end_outputs: self.end_outputs[node].update_activate(scope, status) if self.end_outputs[node].activate and node not in self.processed_outputs: - self.output_queue.put(node) + self.output_queue.append(node) self.processed_outputs.append(node) - if self.activate_end is None and not self.output_queue.empty(): - self.activate_end = self.output_queue.get_nowait() + if self.activate_end is None and self.output_queue: + self.activate_end = self.output_queue.popleft() async def emit_activate_chunk( self, @@ -414,8 +414,8 @@ class StreamOutputCoordinator: async for msg_event in self.emit_activate_chunk(variable_pool, force=True): yield msg_event - if not self.output_queue.empty(): - self.activate_end = self.output_queue.get_nowait() + if self.output_queue: + self.activate_end = self.output_queue.popleft() # Move to next active End node if current one is done if not self.activate_end and self.end_outputs: self.activate_end = list(self.end_outputs.keys())[0] From 04c54081c811e0963c7c1f43cdc0b76afa09482f Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 24 Mar 2026 17:29:38 +0800 Subject: [PATCH 011/120] [add] celery support rbmq --- api/app/celery_app.py | 22 +++++++++++++--------- api/app/core/config.py | 4 ++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 864bee4a..58c89f8f 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,5 +1,6 @@ import os import platform +import re from datetime import timedelta from urllib.parse import quote @@ -11,21 +12,24 @@ from app.core.logging_config import get_logger logger = get_logger(__name__) + +def _mask_url(url: str) -> str: + """隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议""" + return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url) + # macOS fork() safety - must be set before any Celery initialization if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') # 创建 Celery 应用实例 -# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定) -# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定) +# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议), +# 未配置则回退到 Redis 方案 +# backend: 结果存储(使用 Redis) # NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND, # 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md -# Build canonical broker/backend URLs and force them into os.environ so that -# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first) -# cannot be overridden by stray env vars. -# See: https://github.com/celery/celery/issues/4284 -_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" +_broker_url = os.getenv("CELERY_BROKER_URL") or \ + f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" _backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" os.environ["CELERY_BROKER_URL"] = _broker_url os.environ["CELERY_RESULT_BACKEND"] = _backend_url @@ -45,8 +49,8 @@ celery_app = Celery( logger.info( "Celery app initialized", extra={ - "broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"), - "backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"), + "broker": _mask_url(_broker_url), + "backend": _mask_url(_backend_url), }, ) # Default queue for unrouted tasks diff --git a/api/app/core/config.py b/api/app/core/config.py index 4a944557..64c5520e 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -231,8 +231,8 @@ class Settings: # Celery configuration (internal) # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持 # 详见 docs/celery-env-bug-report.md - # 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离 - # 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰 + # 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离 + # 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3")) REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4")) From 522eb569f15dcc197fa25d0b79e7a123ba358458 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Mar 2026 18:10:45 +0800 Subject: [PATCH 012/120] fix(memory): fix undefined logger causing logging errors in memory module --- api/app/services/memory_agent_service.py | 2 +- api/app/services/memory_perceptual_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 3cfcc1d6..e5c34492 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -348,7 +348,7 @@ class MemoryAgentService: perceptual_serivce = MemoryPerceptualService(db) for message in messages: message["file_content"] = [] - for file in message["files"]: + for file in (message.get("files") or []): file_object = await perceptual_serivce.generate_perceptual_memory( end_user_id=end_user_id, memory_config=memory_config, diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index effceda7..3ee238e2 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -278,7 +278,7 @@ class MemoryPerceptualService: files=[file] ) if not file_message: - logger.warning(f"Unsupport file type {file}, model capability: {model_config.capability}") + business_logger.warning(f"Unsupported file type {file}, model capability: {model_config.capability}") return None file_message = file_message[0] try: From 54cff5861a977d49737b3fb1fddd9652c3537a14 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 25 Mar 2026 11:45:49 +0800 Subject: [PATCH 013/120] feat(model): add volcano model --- .../core/memory/llm_tools/openai_embedder.py | 17 +- api/app/core/models/__init__.py | 5 +- api/app/core/models/base.py | 6 +- api/app/core/models/embedding.py | 181 ++++++++- api/app/core/models/generation.py | 345 ++++++++++++++++++ .../core/models/scripts/volcano_models.yaml | 334 +++++++++++++++++ .../vdb/elasticsearch/elasticsearch_vector.py | 34 +- api/app/models/models_model.py | 5 +- api/app/repositories/model_repository.py | 1 - api/app/services/generation_service.py | 164 +++++++++ api/app/services/model_service.py | 61 +++- api/app/services/multimodal_service.py | 1 + api/pyproject.toml | 1 + 13 files changed, 1122 insertions(+), 33 deletions(-) create mode 100644 api/app/core/models/generation.py create mode 100644 api/app/core/models/scripts/volcano_models.yaml create mode 100644 api/app/services/generation_service.py diff --git a/api/app/core/memory/llm_tools/openai_embedder.py b/api/app/core/memory/llm_tools/openai_embedder.py index 2d6fccbc..6ae87887 100644 --- a/api/app/core/memory/llm_tools/openai_embedder.py +++ b/api/app/core/memory/llm_tools/openai_embedder.py @@ -2,6 +2,7 @@ OpenAI Embedder 客户端实现 基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 +自动支持火山引擎的多模态 Embedding。 """ from typing import List @@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import ( ) from app.core.models.base import RedBearModelConfig from app.core.models.embedding import RedBearEmbeddings +from app.models.models_model import ModelProvider logger = logging.getLogger(__name__) @@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient): - 批量文本嵌入 - 自动重试机制 - 错误处理 + - 火山引擎多模态 Embedding(自动识别) """ def __init__(self, model_config: RedBearModelConfig): @@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient): """ super().__init__(model_config) - # 初始化 RedBearEmbeddings 模型 + # 初始化 RedBearEmbeddings(自动支持火山引擎多模态) self.model = RedBearEmbeddings( RedBearModelConfig( model_name=self.model_name, @@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient): timeout=self.timeout, ) ) + self.is_multimodal = self.model.is_multimodal_supported() - logger.info("OpenAI Embedder 客户端初始化完成") + logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})") async def response( self, @@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient): return [] # 生成嵌入向量 - embeddings = await self.model.aembed_documents(texts) + if self.is_multimodal: + # 火山引擎多模态 Embedding + embeddings = await self.model.aembed_multimodal( + [{"type": "text", "text": text} for text in texts] + ) + else: + # 普通 Embedding + embeddings = await self.model.aembed_documents(texts) logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") return embeddings diff --git a/api/app/core/models/__init__.py b/api/app/core/models/__init__.py index f54afc08..f98d073f 100644 --- a/api/app/core/models/__init__.py +++ b/api/app/core/models/__init__.py @@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto from .llm import RedBearLLM from .embedding import RedBearEmbeddings from .rerank import RedBearRerank +from .generation import RedBearImageGenerator, RedBearVideoGenerator __all__ = [ "RedBearModelConfig", @@ -9,5 +10,7 @@ __all__ = [ "RedBearEmbeddings", "RedBearRerank", "RedBearModelFactory", - "get_provider_llm_class" + "get_provider_llm_class", + "RedBearImageGenerator", + "RedBearVideoGenerator" ] \ No newline at end of file diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 4a453c6b..80117f27 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -67,7 +67,7 @@ class RedBearModelFactory: **config.extra_params } - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: # 使用 httpx.Timeout 对象来设置详细的超时配置 # 这样可以分别控制连接超时和读取超时 import httpx @@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: return ChatOpenAI - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]: if type == ModelType.LLM: return OpenAI elif type == ModelType.CHAT: return ChatOpenAI + else: + raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED) elif provider == ModelProvider.DASHSCOPE: return ChatTongyi elif provider == ModelProvider.OLLAMA: diff --git a/api/app/core/models/embedding.py b/api/app/core/models/embedding.py index 16af2567..9ccf53de 100644 --- a/api/app/core/models/embedding.py +++ b/api/app/core/models/embedding.py @@ -1,23 +1,190 @@ -from typing import Any, Dict, List, Optional, TypeVar, Callable +from typing import Any, Dict, List, Optional, Union from langchain_core.embeddings import Embeddings -from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory +from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory +from app.models.models_model import ModelProvider + class RedBearEmbeddings(Embeddings): - """Embedding → 完全符合 LangChain Embeddings""" + """统一的 Embedding 类,自动支持多模态(根据 provider 判断)""" + def __init__(self, config: RedBearModelConfig): - self._model = self._create_model(config) self._config = config + self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO + + if self._is_volcano: + # 火山引擎使用 Ark SDK + self._client = self._create_volcano_client(config) + self._model = None + else: + # 其他 provider 使用 LangChain + self._model = self._create_model(config) + self._client = None def _create_model(self, config: RedBearModelConfig) -> Embeddings: - """根据配置创建模型""" + """根据配置创建 LangChain 模型""" embedding_class = get_provider_embedding_class(config.provider) model_params = RedBearModelFactory.get_model_params(config) return embedding_class(**model_params) + + def _create_volcano_client(self, config: RedBearModelConfig): + """创建火山引擎客户端""" + from volcenginesdkarkruntime import Ark + return Ark(api_key=config.api_key, base_url=config.base_url) + # ==================== LangChain 标准接口 ==================== + def embed_documents(self, texts: list[str]) -> list[list[float]]: - return self._model.embed_documents(texts) + """批量文本向量化(LangChain 标准接口)""" + if self._is_volcano: + # 火山引擎多模态 Embedding + contents = [{"type": "text", "text": text} for text in texts] + response = self._client.multimodal_embeddings.create( + model=self._config.model_name, + input=contents, + encoding_format="float" + ) + return [response.data.embedding] + else: + # 其他 provider + return self._model.embed_documents(texts) def embed_query(self, text: str) -> List[float]: - return self._model.embed_query(text) + """单个文本向量化(LangChain 标准接口)""" + if self._is_volcano: + # 火山引擎多模态 Embedding + result = self.embed_documents([text]) + return result[0] if result else [] + else: + # 其他 provider + return self._model.embed_query(text) + + # ==================== 多模态扩展方法 ==================== + + def embed_multimodal( + self, + contents: List[Dict[str, Any]], + **kwargs + ) -> List[List[float]]: + """ + 多模态向量化(仅火山引擎支持) + + Args: + contents: 内容列表,格式: + - 文本: {"type": "text", "text": "..."} + - 图片: {"type": "image_url", "image_url": {"url": "..."}} + - 视频: {"type": "video_url", "video_url": {"url": "..."}} + **kwargs: 其他参数 + + Returns: + 向量列表 + """ + if not self._is_volcano: + raise NotImplementedError( + f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + response = self._client.multimodal_embeddings.create( + model=self._config.model_name, + input=contents, + **kwargs + ) + return [item.embedding for item in response.data] + + async def aembed_multimodal( + self, + contents: List[Dict[str, Any]], + **kwargs + ) -> List[List[float]]: + """异步多模态向量化""" + # 火山引擎 SDK 暂不支持异步,使用同步方法 + return self.embed_multimodal(contents, **kwargs) + + def embed_text(self, text: str, **kwargs) -> List[float]: + """文本向量化(便捷方法)""" + if self._is_volcano: + result = self.embed_multimodal( + [{"type": "text", "text": text}], + **kwargs + ) + return result[0] if result else [] + else: + return self.embed_query(text) + + def embed_image(self, image_url: str, **kwargs) -> List[float]: + """图片向量化(仅火山引擎支持)""" + if not self._is_volcano: + raise NotImplementedError( + f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + result = self.embed_multimodal( + [{"type": "image_url", "image_url": {"url": image_url}}], + **kwargs + ) + return result[0] if result else [] + + def embed_video(self, video_url: str, **kwargs) -> List[float]: + """视频向量化(仅火山引擎支持)""" + if not self._is_volcano: + raise NotImplementedError( + f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + result = self.embed_multimodal( + [{"type": "video_url", "video_url": {"url": video_url}}], + **kwargs + ) + return result[0] if result else [] + + def embed_batch( + self, + items: List[Union[str, Dict[str, Any]]], + **kwargs + ) -> List[List[float]]: + """ + 批量向量化(支持混合类型) + + Args: + items: 可以是字符串列表或内容字典列表 + **kwargs: 其他参数 + + Returns: + 向量列表 + """ + # 如果全是字符串,使用标准方法 + if all(isinstance(item, str) for item in items): + return self.embed_documents(items) + + # 如果包含字典,需要多模态支持 + if not self._is_volcano: + raise NotImplementedError( + f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}" + ) + + # 标准化输入格式 + contents = [] + for item in items: + if isinstance(item, str): + contents.append({"type": "text", "text": item}) + elif isinstance(item, dict): + contents.append(item) + else: + raise ValueError(f"不支持的输入类型: {type(item)}") + + return self.embed_multimodal(contents, **kwargs) + + # ==================== 工具方法 ==================== + + def is_multimodal_supported(self) -> bool: + """检查是否支持多模态""" + return self._is_volcano + + def get_provider(self) -> str: + """获取 provider""" + return self._config.provider + + +# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容 +RedBearMultimodalEmbeddings = RedBearEmbeddings diff --git a/api/app/core/models/generation.py b/api/app/core/models/generation.py new file mode 100644 index 00000000..98b23fbf --- /dev/null +++ b/api/app/core/models/generation.py @@ -0,0 +1,345 @@ +""" +图片和视频生成模型封装 + +支持的 Provider: +- Volcano (火山引擎): 使用 volcenginesdkarkruntime +- OpenAI: 使用 openai SDK +""" +from typing import Any, Dict, Optional + +from volcenginesdkarkruntime import Ark +from volcenginesdkarkruntime.types.images.images import ( + SequentialImageGenerationOptions, + ContentGenerationTool, + OptimizePromptOptions +) + +from app.core.models.base import RedBearModelConfig +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.models.models_model import ModelProvider + + +class RedBearImageGenerator: + """图片生成模型封装""" + + def __init__(self, config: RedBearModelConfig): + self._config = config + self._client = self._create_client(config) + + def _create_client(self, config: RedBearModelConfig): + """根据 provider 创建客户端""" + provider = config.provider.lower() + + if provider == ModelProvider.VOLCANO: + return Ark(api_key=config.api_key, base_url=config.base_url) + # elif provider == ModelProvider.OPENAI: + # from openai import OpenAI + # return OpenAI(api_key=config.api_key, base_url=config.base_url) + else: + raise BusinessException( + f"不支持的图片生成提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + def generate( + self, + prompt: str, + image: Optional[Any] = None, + size: Optional[str] = "2K", + output_format: str = "png", + response_format: str = "url", + watermark: bool = False, + sequential_image_generation: Optional[str] = None, + sequential_image_generation_options: Optional[Dict] = None, + tools: Optional[list] = None, + optimize_prompt_options: Optional[Dict] = None, + stream: bool = False, + **kwargs + ) -> Dict[str, Any]: + """ + 生成图片 + + Args: + prompt: 提示词 + image: 参考图片URL或URL列表(图文生图/多图融合) + size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080" 等(至少3686400像素) + n: 生成数量 + output_format: 输出格式,如 "png", "jpg" + response_format: 返回格式,"url" 或 "b64_json" + watermark: 是否添加水印 + sequential_image_generation: 组图生成模式,"auto" 或 "disabled" + sequential_image_generation_options: 组图生成选项,如 {"max_images": 4} + tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图 + optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"} + stream: 是否使用流式生成 + **kwargs: 其他参数 + + Returns: + 生成结果 + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + params = { + "model": self._config.model_name, + "prompt": prompt, + "size": size, + "output_format": output_format, + "response_format": response_format, + "watermark": watermark, + } + + if image is not None: + params["image"] = image + + if sequential_image_generation: + params["sequential_image_generation"] = sequential_image_generation + if sequential_image_generation_options: + params["sequential_image_generation_options"] = SequentialImageGenerationOptions( + **sequential_image_generation_options + ) + + if tools: + params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools] + + if optimize_prompt_options: + params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options) + + if stream: + params["stream"] = True + + params.update(kwargs) + response = self._client.images.generate(**params) + + # elif provider == ModelProvider.OPENAI: + # response = self._client.images.generate( + # model=self._config.model_name, + # prompt=prompt, + # size=size, + # n=n, + # **kwargs + # ) + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + return response.model_dump() if hasattr(response, 'model_dump') else response + + async def agenerate( + self, + prompt: str, + image: Optional[Any] = None, + size: Optional[str] = "2K", + output_format: str = "png", + response_format: str = "url", + watermark: bool = False, + **kwargs + ) -> Dict[str, Any]: + """异步生成图片""" + return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs) + + +class RedBearVideoGenerator: + """视频生成模型封装""" + + def __init__(self, config: RedBearModelConfig): + self._config = config + self._client = self._create_client(config) + + def _create_client(self, config: RedBearModelConfig): + """根据 provider 创建客户端""" + provider = config.provider.lower() + + if provider == ModelProvider.VOLCANO: + return Ark(api_key=config.api_key, base_url=config.base_url) + else: + raise BusinessException( + f"不支持的视频生成提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + def generate( + self, + prompt: str, + image_url: Optional[str] = None, + first_frame_url: Optional[str] = None, + last_frame_url: Optional[str] = None, + reference_images: Optional[list] = None, + draft_task_id: Optional[str] = None, + duration: Optional[int] = None, + frames: Optional[int] = None, + ratio: Optional[str] = None, + resolution: Optional[str] = None, + generate_audio: bool = False, + watermark: bool = False, + camera_fixed: bool = False, + seed: Optional[int] = None, + return_last_frame: bool = False, + service_tier: str = "default", + execution_expires_after: Optional[int] = None, + draft: bool = False, + **kwargs + ) -> Dict[str, Any]: + """ + 生成视频 + + Args: + prompt: 提示词 + image_url: 首帧图片URL(图生视频-基于首帧) + first_frame_url: 首帧图片URL(图生视频-基于首尾帧) + last_frame_url: 尾帧图片URL(图生视频-基于首尾帧) + reference_images: 参考图片URL列表(图生视频-基于参考图) + draft_task_id: Draft任务ID(基于Draft生成正式视频) + duration: 视频时长(秒),与frames二选一 + frames: 视频帧数,与duration二选一 + ratio: 视频比例,如 "16:9", "9:16", "adaptive" + resolution: 视频分辨率,如 "720p", "1080p" + generate_audio: 是否生成音频 + watermark: 是否添加水印 + camera_fixed: 是否固定镜头 + seed: 随机种子 + return_last_frame: 是否返回最后一帧 + service_tier: 服务层级,"default" 或 "flex"(离线推理) + execution_expires_after: 任务过期时间(秒) + draft: 是否生成样片 + **kwargs: 其他参数 + + Returns: + 生成结果(包含任务ID,需要轮询获取结果) + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + content = [{"type": "text", "text": prompt}] + + if draft_task_id: + content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}] + else: + if image_url: + content.append({"type": "image_url", "image_url": {"url": image_url}}) + + if first_frame_url: + content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"}) + if last_frame_url: + content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"}) + + if reference_images: + for ref_url in reference_images: + content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"}) + + params = {"model": self._config.model_name, "content": content, "watermark": watermark} + + if duration: + params["duration"] = duration + if frames: + params["frames"] = frames + if ratio: + params["ratio"] = ratio + if resolution: + params["resolution"] = resolution + if generate_audio: + params["generate_audio"] = generate_audio + if camera_fixed: + params["camera_fixed"] = camera_fixed + if seed is not None: + params["seed"] = seed + if return_last_frame: + params["return_last_frame"] = return_last_frame + if service_tier != "default": + params["service_tier"] = service_tier + if execution_expires_after: + params["execution_expires_after"] = execution_expires_after + if draft: + params["draft"] = draft + + params.update(kwargs) + response = self._client.content_generation.tasks.create(**params) + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + return response.model_dump() if hasattr(response, 'model_dump') else response + + async def agenerate( + self, + prompt: str, + image_url: Optional[str] = None, + duration: Optional[int] = None, + **kwargs + ) -> Dict[str, Any]: + """异步生成视频""" + return self.generate(prompt, image_url=image_url, duration=duration, **kwargs) + + def get_task_status(self, task_id: str) -> Dict[str, Any]: + """ + 查询视频生成任务状态 + + Args: + task_id: 任务ID + + Returns: + 任务状态信息 + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + response = self._client.content_generation.tasks.get(task_id=task_id) + return response.model_dump() if hasattr(response, 'model_dump') else response + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + async def aget_task_status(self, task_id: str) -> Dict[str, Any]: + """异步查询任务状态""" + return self.get_task_status(task_id) + + def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]: + """ + 查询视频生成任务列表 + + Args: + page_size: 每页数量 + status: 任务状态筛选,如 "succeeded", "failed", "pending" + **kwargs: 其他参数 + + Returns: + 任务列表 + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + params = {"page_size": page_size} + if status: + params["status"] = status + params.update(kwargs) + response = self._client.content_generation.tasks.list(**params) + return response.model_dump() if hasattr(response, 'model_dump') else response + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) + + def delete_task(self, task_id: str) -> None: + """ + 删除或取消视频生成任务 + + Args: + task_id: 任务ID + """ + provider = self._config.provider.lower() + + if provider == ModelProvider.VOLCANO: + self._client.content_generation.tasks.delete(task_id=task_id) + else: + raise BusinessException( + f"不支持的提供商: {provider}", + code=BizCode.PROVIDER_NOT_SUPPORTED + ) diff --git a/api/app/core/models/scripts/volcano_models.yaml b/api/app/core/models/scripts/volcano_models.yaml new file mode 100644 index 00000000..24609f5a --- /dev/null +++ b/api/app/core/models/scripts/volcano_models.yaml @@ -0,0 +1,334 @@ +provider: volcano +models: +# Doubao-Seed 2.0 系列 +- name: doubao-seed-2-0-pro-260215 + type: chat + provider: volcano + description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-2-0-lite-260215 + type: chat + provider: volcano + description: 面向高频企业场景兼顾性能与成本的均衡型模型,综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-2-0-mini-260215 + type: chat + provider: volcano + description: 面向低时延、高并发与成本敏感场景,提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解,适合成本和速度优先的轻量级任务。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-2-0-code-preview-260215 + type: chat + provider: volcano + description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 代码模型 + logo: volcano + +# Doubao-Seed 1.x 系列 +- name: doubao-seed-1-8-251228 + type: chat + provider: volcano + description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-251015 + type: chat + provider: volcano + description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-lite-251015 + type: chat + provider: volcano + description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-flash-250828 + type: chat + provider: volcano + description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-code-preview-251028 + type: chat + provider: volcano + description: 面向Agentic编程任务进行了深度优化。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 代码模型 + logo: volcano + +- name: doubao-seed-1-6-vision-250815 + type: chat + provider: volcano + description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 多模态模型 + logo: volcano + +# Doubao 1.5 系列 +- name: doubao-1-5-vision-pro-32k-250115 + type: chat + provider: volcano + description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 大语言模型 + - 多模态模型 + logo: volcano + +- name: doubao-1-5-pro-32k-250115 + type: chat + provider: volcano + description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-1-5-lite-32k-250115 + type: chat + provider: volcano + description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 大语言模型 + logo: volcano + +# Doubao-Seedance 视频生成系列 +- name: doubao-seedance-1-5-pro-251215 + type: video + provider: volcano + description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-pro-250528 + type: video + provider: volcano + description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-pro-fast-251015 + type: video + provider: volcano + description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-lite-i2v-250428 + type: video + provider: volcano + description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + - 图生视频 + logo: volcano + +- name: doubao-seedance-1-0-lite-t2v-250428 + type: video + provider: volcano + description: 基于文本提示词生成视频 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 视频生成 + - 文生视频 + logo: volcano + +# Doubao-Seedream 图像生成系列 +- name: doubao-seedream-5-0-260128 + type: image + provider: volcano + description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-4-5-251128 + type: image + provider: volcano + description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-4-0-250828 + type: image + provider: volcano + description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-3-0-t2i-250415 + type: image + provider: volcano + description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 图像生成 + - 文生图 + logo: volcano + +# Doubao 翻译系列 +- name: doubao-seed-translation-250915 + type: chat + provider: volcano + description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 翻译模型 + logo: volcano + +# Doubao Embedding 系列 +- name: doubao-embedding-vision-251215 + type: embedding + provider: volcano + description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 向量模型 + - 多模态模型 + logo: volcano diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 198d1473..386920e0 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel): class ElasticSearchVector(BaseVector): def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): super().__init__(index_name.lower()) - # self.embeddings = XinferenceEmbeddings( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port - # model_uid="bge-m3" # replace model_uid with the model UID return from launching the model - # ) - # Remove debug printing to avoid leaking sensitive information - # print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base) + + # 初始化 Embedding 模型(自动支持火山引擎多模态) self.embeddings = RedBearEmbeddings(RedBearModelConfig( model_name=embedding_config.model_name, provider=embedding_config.provider, api_key=embedding_config.api_key, base_url=embedding_config.api_base )) - # self.reranker = XinferenceRerank( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), - # model_uid="bge-reranker-large" - # ) - # Remove debug printing to avoid leaking sensitive information - # print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base) + self.is_multimodal_embedding = self.embeddings.is_multimodal_supported() + self.reranker = RedBearRerank(RedBearModelConfig( model_name=reranker_config.model_name, provider=reranker_config.provider, @@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector): def add_chunks(self, chunks: list[DocumentChunk], **kwargs): # 实现 Elasticsearch 保存向量 texts = [chunk.page_content for chunk in chunks] - embeddings = self.embeddings.embed_documents(list(texts)) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + embeddings = self.embeddings.embed_batch(texts) + else: + embeddings = self.embeddings.embed_documents(list(texts)) self.create(chunks, embeddings, **kwargs) def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): @@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector): updated count. """ indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - chunk.vector = self.embeddings.embed_query(chunk.page_content) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + chunk.vector = self.embeddings.embed_text(chunk.page_content) + else: + chunk.vector = self.embeddings.embed_query(chunk.page_content) body = { "script": { @@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: """Search the nearest neighbors to a vector.""" - query_vector = self.embeddings.embed_query(query) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + query_vector = self.embeddings.embed_text(query) + else: + query_vector = self.embeddings.embed_query(query) top_k = kwargs.get("top_k", 1024) score_threshold = float(kwargs.get("score_threshold") or 0.3) indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 23fafcef..a16a4073 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -26,9 +26,9 @@ class ModelType(StrEnum): RERANK = "rerank" # TTS = "tts" # SPEECH2TEXT = "speech2text" - # IMAGE = "image" + IMAGE = "image" # AUDIO = "audio" - # VISION = "vision" + VIDEO = "video" class ModelProvider(StrEnum): @@ -45,6 +45,7 @@ class ModelProvider(StrEnum): XINFERENCE = "xinference" GPUSTACK = "gpustack" BEDROCK = "bedrock" + VOLCANO = "volcano" COMPOSITE = "composite" diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index f49227d3..90ada6fa 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -435,7 +435,6 @@ class ModelConfigRepository: ModelConfig.is_public ), ModelConfig.provider == provider, - ModelConfig.is_active, ~ModelConfig.is_composite ) ).all() diff --git a/api/app/services/generation_service.py b/api/app/services/generation_service.py new file mode 100644 index 00000000..e7800ef6 --- /dev/null +++ b/api/app/services/generation_service.py @@ -0,0 +1,164 @@ +""" +图片和视频生成服务 + +提供统一的生成接口,支持多种 Provider +""" +from typing import Dict, Any, Optional +from sqlalchemy.orm import Session +import uuid + +from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.models.models_model import ModelType +from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository +from app.services.model_service import ModelApiKeyService + + +class GenerationService: + """生成服务""" + + def __init__(self, db: Session): + self.db = db + + async def generate_image( + self, + model_config_id: str, + prompt: str, + size: Optional[str] = "1024x1024", + n: int = 1, + **kwargs + ) -> Dict[str, Any]: + """ + 生成图片 + + Args: + model_config_id: 模型配置ID + prompt: 提示词 + size: 图片尺寸 + n: 生成数量 + **kwargs: 其他参数 + + Returns: + 生成结果 + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + if model_config.type != ModelType.IMAGE: + raise BusinessException( + f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}", + code=BizCode.INVALID_PARAMETER + ) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 生成图片 + generator = RedBearImageGenerator(config) + result = await generator.agenerate(prompt, size, n, **kwargs) + + return result + + async def generate_video( + self, + model_config_id: str, + prompt: str, + duration: Optional[int] = None, + **kwargs + ) -> Dict[str, Any]: + """ + 生成视频 + + Args: + model_config_id: 模型配置ID + prompt: 提示词 + duration: 视频时长(秒) + **kwargs: 其他参数 + + Returns: + 生成结果(包含任务ID) + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + if model_config.type != ModelType.VIDEO: + raise BusinessException( + f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}", + code=BizCode.INVALID_PARAMETER + ) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 生成视频 + generator = RedBearVideoGenerator(config) + result = await generator.agenerate(prompt, duration, **kwargs) + + return result + + async def get_video_task_status( + self, + model_config_id: str, + task_id: str + ) -> Dict[str, Any]: + """ + 查询视频生成任务状态 + + Args: + model_config_id: 模型配置ID + task_id: 任务ID + + Returns: + 任务状态信息 + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 查询任务状态 + generator = RedBearVideoGenerator(config) + result = await generator.aget_task_status(task_id) + + return result diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index a7398504..b98674ba 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -154,10 +154,17 @@ class ModelConfigService: } elif model_type_lower == "embedding": - # Embedding 模型验证(在线程中运行同步方法) + # Embedding 模型验证 + # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] - vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) + + # 火山引擎使用 embed_batch,其他使用 embed_documents + if provider.lower() == "volcano": + vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) + else: + vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) + elapsed_time = time.time() - start_time return { @@ -193,6 +200,56 @@ class ModelConfigService: }, "error": None } + + elif model_type_lower == "image": + # 图片生成模型验证 + from app.core.models.generation import RedBearImageGenerator + + generator = RedBearImageGenerator(model_config) + result = await generator.agenerate( + prompt="a cute panda", + size="2K" + ) + elapsed_time = time.time() - start_time + logger.info(f"成功生成图片,结果: {result}") + + return { + "valid": True, + "message": "图片生成模型配置验证成功", + "response": f"成功生成图片,结果: {result}", + "elapsed_time": elapsed_time, + "usage": { + "prompt_length": len("a cute panda"), + "image_count": 1 + }, + "error": None + } + + elif model_type_lower == "video": + # 视频生成模型验证 + from app.core.models.generation import RedBearVideoGenerator + + generator = RedBearVideoGenerator(model_config) + result = await generator.agenerate( + prompt="a cute panda playing in bamboo forest", + duration=5 + ) + elapsed_time = time.time() - start_time + + # 视频生成是异步任务,返回任务ID + task_id = result.get("task_id") if isinstance(result, dict) else None + + return { + "valid": True, + "message": "视频生成模型配置验证成功", + "response": f"成功创建视频生成任务,任务ID: {task_id}", + "elapsed_time": elapsed_time, + "usage": { + "prompt_length": len("a cute panda playing in bamboo forest"), + "task_id": task_id + }, + "error": None + } else: return { diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 6cb0a7f0..583a33b8 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -297,6 +297,7 @@ PROVIDER_STRATEGIES = { "bedrock": BedrockFormatStrategy, "anthropic": BedrockFormatStrategy, "openai": OpenAIFormatStrategy, + "volcano": OpenAIFormatStrategy, } diff --git a/api/pyproject.toml b/api/pyproject.toml index e6fddea8..8ced574c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -147,6 +147,7 @@ dependencies = [ "modelscope>=1.34.0", "python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'", "python-magic-bin>=0.4.14; sys_platform=='win32'", + "volcengine-python-sdk[ark]==5.0.19" ] [tool.pytest.ini_options] From e86d679ae5696c1bc4cb08fc326a2eae00df5be2 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 25 Mar 2026 12:37:16 +0800 Subject: [PATCH 014/120] perf(workflow): improve performance of workflow analysis algorithms, fix typos, adjust debug log levels --- .../core/workflow/adapters/base_adapter.py | 10 +- .../core/workflow/adapters/dify/converter.py | 64 ++++++------- .../workflow/adapters/dify/dify_adapter.py | 19 ++-- api/app/core/workflow/adapters/errors.py | 20 ++-- .../memory_bear/memory_bear_adapter.py | 12 +-- .../memory_bear/memory_bear_converter.py | 4 +- api/app/core/workflow/engine/graph_builder.py | 19 ++-- .../core/workflow/engine/result_builder.py | 17 +++- .../core/workflow/engine/runtime_schema.py | 3 + api/app/core/workflow/executor.py | 95 ++----------------- api/app/core/workflow/nodes/agent/node.py | 2 - api/app/core/workflow/nodes/base_node.py | 10 +- api/app/core/workflow/nodes/end/config.py | 4 +- api/app/core/workflow/nodes/end/node.py | 9 -- api/app/core/workflow/nodes/enums.py | 2 +- .../workflow/nodes/http_request/config.py | 4 +- .../core/workflow/nodes/http_request/node.py | 2 +- api/app/core/workflow/nodes/start/node.py | 7 +- .../{file_processer.py => file_processor.py} | 0 api/app/core/workflow/validator.py | 22 ++--- .../workflow/variable/variable_objects.py | 2 +- api/app/services/workflow_import_service.py | 4 +- 22 files changed, 127 insertions(+), 204 deletions(-) rename api/app/core/workflow/utils/{file_processer.py => file_processor.py} (100%) diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py index 49321b89..2e24d085 100644 --- a/api/app/core/workflow/adapters/base_adapter.py +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -9,7 +9,7 @@ from typing import Any from pydantic import BaseModel, Field -from app.core.workflow.adapters.errors import ExceptionDefineition +from app.core.workflow.adapters.errors import ExceptionDefinition from app.schemas.workflow_schema import ( EdgeDefinition, NodeDefinition, @@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) - warnings: list[ExceptionDefineition] = Field(default_factory=list) - errors: list[ExceptionDefineition] = Field(default_factory=list) + warnings: list[ExceptionDefinition] = Field(default_factory=list) + errors: list[ExceptionDefinition] = Field(default_factory=list) class WorkflowImportResult(BaseModel): @@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) - warnings: list[ExceptionDefineition] = Field(default_factory=list) - errors: list[ExceptionDefineition] = Field(default_factory=list) + warnings: list[ExceptionDefinition] = Field(default_factory=list) + errors: list[ExceptionDefinition] = Field(default_factory=list) class BasePlatformAdapter(ABC): diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 467beb07..4fa9508b 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -9,9 +9,9 @@ from urllib.parse import quote from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.errors import ( - UnsupportVariableType, - UnknowModelWarning, - ExceptionDefineition, + UnsupportedVariableType, + UnknownModelWarning, + ExceptionDefinition, ExceptionType ) from app.core.workflow.nodes.assigner.config import AssignmentItem @@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import ( HttpFormData, HttpTimeOutConfig, HttpRetryConfig, - HttpErrorDefaultTamplete, + HttpErrorDefaultTemplate, HttpErrorHandleConfig ) from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig @@ -108,7 +108,7 @@ class DifyConverter(BaseConverter): try: return config.model_validate(value) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node_id, node_name=node_name, @@ -138,7 +138,7 @@ class DifyConverter(BaseConverter): var_selector = mapping.get(var_selector, var_selector) return var_selector - def _process_list_variable_litearl(self, variable_selector: list) -> str | None: + def _process_list_variable_literal(self, variable_selector: list) -> str | None: if not self.process_var_selector(".".join(variable_selector)): return None return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" @@ -269,7 +269,7 @@ class DifyConverter(BaseConverter): var_type = self.variable_type_map(var["type"]) if not var_type: self.errors.append( - UnsupportVariableType( + UnsupportedVariableType( scope=node["id"], name=var["variable"], var_type=var["type"], @@ -281,7 +281,7 @@ class DifyConverter(BaseConverter): if var_type in ["file", "array[file]"]: self.errors.append( - ExceptionDefineition( + ExceptionDefinition( type=ExceptionType.VARIABLE, node_id=node["id"], node_name=node_data["title"], @@ -311,7 +311,7 @@ class DifyConverter(BaseConverter): def convert_question_classifier_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") @@ -327,7 +327,7 @@ class DifyConverter(BaseConverter): ) result = QuestionClassifierNodeConfig.model_construct( - input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")), + input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), categories=categories, ).model_dump() @@ -337,13 +337,13 @@ class DifyConverter(BaseConverter): def convert_llm_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") ) ) - context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) + context = self._process_list_variable_literal(node_data["context"]["variable_selector"]) memory = MemoryWindowSetting( enable=bool(node_data.get("memory")), enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), @@ -367,7 +367,7 @@ class DifyConverter(BaseConverter): ) ) vision = node_data["vision"]["enabled"] - vision_input = self._process_list_variable_litearl( + vision_input = self._process_list_variable_literal( node_data["vision"]["configs"]["variable_selector"] ) if vision else None result = LLMNodeConfig.model_construct( @@ -433,7 +433,7 @@ class DifyConverter(BaseConverter): conditions.append( LoopConditionDetail.model_construct( operator=self.convert_compare_operator(condition["comparison_operator"]), - left=self._process_list_variable_litearl(condition["variable_selector"]), + left=self._process_list_variable_literal(condition["variable_selector"]), right=self.trans_variable_format( right_value ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( @@ -453,7 +453,7 @@ class DifyConverter(BaseConverter): right_input_type = variable["value_type"] right_value_type = self.variable_type_map(variable["var_type"]) if right_input_type == ValueInputType.VARIABLE: - right_value = self._process_list_variable_litearl(variable.get("value", "")) + right_value = self._process_list_variable_literal(variable.get("value", "")) else: right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) loop_variables.append( @@ -475,10 +475,10 @@ class DifyConverter(BaseConverter): def convert_iteration_node_config(self, node: dict) -> dict: node_data = node["data"] result = IterationNodeConfig.model_construct( - input=self._process_list_variable_litearl(node_data["iterator_selector"]), + input=self._process_list_variable_literal(node_data["iterator_selector"]), parallel=node_data["is_parallel"], parallel_count=node_data["parallel_nums"], - output=self._process_list_variable_litearl(node_data["output_selector"]), + output=self._process_list_variable_literal(node_data["output_selector"]), output_type=self.variable_type_map(node_data.get("output_type")), flatten=node_data["flatten_output"], ).model_dump() @@ -494,8 +494,8 @@ class DifyConverter(BaseConverter): continue assignments.append( AssignmentItem( - variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), - value=self._process_list_variable_litearl( + variable_selector=self._process_list_variable_literal(assignment["variable_selector"]), + value=self._process_list_variable_literal( assignment["value"] ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], operation=self.convert_assignment_operator(assignment["operation"]) @@ -514,7 +514,7 @@ class DifyConverter(BaseConverter): input_variables.append( InputVariable.model_construct( name=input_variable["variable"], - variable=self._process_list_variable_litearl(input_variable["value_selector"]), + variable=self._process_list_variable_literal(input_variable["value_selector"]), ) ) @@ -570,7 +570,7 @@ class DifyConverter(BaseConverter): else: if node_data["body"]["data"]: body_content = (node_data["body"]["data"][0].get("value") or - self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) + self._process_list_variable_literal(node_data["body"]["data"][0].get("file"))) else: body_content = "" @@ -585,7 +585,7 @@ class DifyConverter(BaseConverter): self.trans_variable_format(key_value[0]) ] = self.trans_variable_format(key_value[1]) else: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node["id"], node_name=node_data["title"], @@ -603,7 +603,7 @@ class DifyConverter(BaseConverter): self.trans_variable_format(key_value[0]) ] = self.trans_variable_format(key_value[1]) else: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node["id"], node_name=node_data["title"], @@ -625,7 +625,7 @@ class DifyConverter(BaseConverter): default_header = var["value"] elif var["key"] == "status_code": default_status_code = var["value"] - default_value = HttpErrorDefaultTamplete( + default_value = HttpErrorDefaultTemplate( body=default_body, headers=default_header, status_code=default_status_code, @@ -668,7 +668,7 @@ class DifyConverter(BaseConverter): for variable in node_data["variables"]: mapping.append(VariablesMappingConfig.model_construct( name=variable["variable"], - value=self._process_list_variable_litearl(variable["value_selector"]) + value=self._process_list_variable_literal(variable["value_selector"]) )) result = JinjaRenderNodeConfig.model_construct( template=node_data["template"], @@ -679,14 +679,14 @@ class DifyConverter(BaseConverter): def convert_knowledge_node_config(self, node: dict) -> dict: node_data = node["data"] - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( node_id=node["id"], node_name=node_data["title"], type=ExceptionType.CONFIG, detail=f"Please reconfigure the Knowledge Retrieval node.", )) result = KnowledgeRetrievalNodeConfig.model_construct( - query=self._process_list_variable_litearl(node_data["query_variable_selector"]), + query=self._process_list_variable_literal(node_data["query_variable_selector"]), ).model_dump() self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) @@ -695,7 +695,7 @@ class DifyConverter(BaseConverter): def convert_parameter_extractor_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") @@ -712,7 +712,7 @@ class DifyConverter(BaseConverter): ) ) result = ParameterExtractorNodeConfig.model_construct( - text=self._process_list_variable_litearl(node_data["query"]), + text=self._process_list_variable_literal(node_data["query"]), params=params, prompt=node_data.get("instruction") ).model_dump() @@ -727,14 +727,14 @@ class DifyConverter(BaseConverter): group_type = {} if not advanced_settings or not advanced_settings["group_enabled"]: group_variables = [ - self._process_list_variable_litearl(variable) + self._process_list_variable_literal(variable) for variable in node_data["variables"] ] group_type["output"] = node_data["output_type"] else: for group in advanced_settings["groups"]: group_variables[group["group_name"]] = [ - self._process_list_variable_litearl(variable) + self._process_list_variable_literal(variable) for variable in group["variables"] ] group_type[group["group_name"]] = group["output_type"] @@ -751,7 +751,7 @@ class DifyConverter(BaseConverter): def convert_tool_node_config(self, node: dict) -> dict: node_data = node["data"] - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( node_id=node["id"], node_name=node_data["title"], type=ExceptionType.CONFIG, diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index 10397ad0..abd95408 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import ( WorkflowParserResult ) from app.core.workflow.adapters.dify.converter import DifyConverter -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.nodes.enums import NodeType from app.schemas.workflow_schema import ( NodeDefinition, @@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): if not all(field in self.config for field in require_fields): return False if self.config.get("app", {}).get("mode") == "workflow": - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.PLATFORM, detail="workflow mode is not supported" )) @@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): edge = self._convert_edge(edge) if edge: self.edges.append(edge) - # + for variable in self.config.get("workflow").get("conversation_variables"): con_var = self._convert_variable(variable) if variable: self.conv_variables.append(con_var) - # + # for variables in config.get("workflow").get("environment_variables"): # variable = self._convert_variable(variables) # conv_variables.append(variable) @@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): "y": node["position"]["y"] + position["y"] } self.errors.append( - ExceptionDefineition( + ExceptionDefinition( type=ExceptionType.NODE, node_id=node_id, detail="parent cycle node not found" @@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): node_data = node["data"] converter = self.get_node_convert(node_type) if node_type == NodeType.UNKNOWN: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node["id"], node_name=node["data"]["title"], @@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): )) return converter(node) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node["id"], node_name=node["data"]["title"], @@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: try: - source = edge["source"] target = edge["target"] label = None @@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): label=label, ) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"convert edge error - {e}", )) @@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): description=variable.get("description") ) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.VARIABLE, name=variable.get("name"), detail=f"convert variable error - {e}", diff --git a/api/app/core/workflow/adapters/errors.py b/api/app/core/workflow/adapters/errors.py index c0340a5e..cb743c68 100644 --- a/api/app/core/workflow/adapters/errors.py +++ b/api/app/core/workflow/adapters/errors.py @@ -18,7 +18,7 @@ class ExceptionType(StrEnum): UNKNOWN = "unknown" -class ExceptionDefineition(BaseModel): +class ExceptionDefinition(BaseModel): type: ExceptionType detail: str @@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel): name: str | None = None -class UnknowModelWarning(ExceptionDefineition): +class UnknownModelWarning(ExceptionDefinition): type: ExceptionType = ExceptionType.NODE def __init__(self, node_id, node_name, model_name): @@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition): ) -class UnknowError(ExceptionDefineition): +class UnknownError(ExceptionDefinition): type: ExceptionType = ExceptionType.UNKNOWN def __init__(self, detail: str, **kwargs): super().__init__(detail=detail, **kwargs) -class UnsupportPlatform(ExceptionDefineition): +class UnsupportedPlatform(ExceptionDefinition): type: ExceptionType = ExceptionType.PLATFORM def __init__(self, platform: str): - super().__init__(detail=f"Unsupport platform {platform}") + super().__init__(detail=f"Unsupported platform {platform}") -class UnsupportVariableType(ExceptionDefineition): +class UnsupportedVariableType(ExceptionDefinition): type: ExceptionType = ExceptionType.VARIABLE def __init__(self, scope, name, var_type: str, **kwargs): - super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs) + super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs) -class InvalidConfiguration(ExceptionDefineition): +class InvalidConfiguration(ExceptionDefinition): type: ExceptionType = ExceptionType.CONFIG def __init__(self): super().__init__(detail="Invalid workflow configuration format") -class UnsupportNodeType(ExceptionDefineition): +class UnsupportedNodeType(ExceptionDefinition): type: ExceptionType = ExceptionType.NODE def __init__(self, node_id: str, node_type: str): - super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") + super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}") diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py index 3516cb58..a2608a01 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py @@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import ( BasePlatformAdapter, WorkflowParserResult ) -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter from app.core.workflow.nodes.enums import NodeType from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition @@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): try: node_type = self.map_node_type(node["type"]) if node_type == NodeType.UNKNOWN: - self.errors.append(UnsupportNodeType( + self.errors.append(UnsupportedNodeType( node_id=node_id, node_type=node["type"] )) @@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): return NodeDefinition(**node) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node_id, node_name=node_name, @@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: try: if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"edge {edge.get('id')} skipped: source or target node not found" )) return None return EdgeDefinition(**edge) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"convert edge error - {e}" )) @@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): try: return VariableDefinition(**variable) except Exception as e: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.VARIABLE, name=variable.get("name"), detail=f"convert variable error - {e}" diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py index 031c7025..e96e0bf2 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py @@ -1,6 +1,6 @@ # -*- coding: UTF-8 -*- from app.core.workflow.adapters.base_converter import BaseConverter -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.configs import ( StartNodeConfig, @@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter): try: return config_cls.model_validate(value) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node_id, node_name=node_name, diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index c5cf3324..29f46765 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -69,11 +69,12 @@ class GraphBuilder: for node in self.nodes if node.get("type") == "end" and node.get("id") in self.reachable_nodes ] + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) + self._adj: dict[str, list[str]] = defaultdict(list) + self._build_reverse_adj() self.add_edges() # EDGES MUST BE ADDED AFTER NODES ARE ADDED. - self._reverse_adj: dict[str, list[dict]] = defaultdict(list) - self._build_reverse_adj() self._analyze_end_node_output() @property @@ -115,6 +116,7 @@ class GraphBuilder: self._reverse_adj[edge.get("target")].append({ "id": edge["source"], "branch": edge.get("label") }) + self._adj[edge.get("source")].append(edge["target"]) def _find_upstream_activation_dep( self, @@ -413,11 +415,12 @@ class GraphBuilder: # Add conditional edges for source_node, branches in conditional_edges.items(): def make_router(src, branch_list): - """reate a router function for each source node that routes to a NOP node for later merging.""" + """Create a router function for each source node that routes to a NOP node for later merging.""" def make_branch_node(node_name, targets): def node(s): - # NOTE: NOP NODE MUST NOT MODIFY STATE + # NOTE: NOP NODE USED FOR ROUTING ONLY. + # MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS. return { "activate": { node_id: s["activate"][node_name] @@ -504,11 +507,9 @@ class GraphBuilder: logger.debug(f"Added waiting edge: {sources} -> {target}") # Connect End nodes to the global END node - for end_node in self.end_nodes: - end_node_id = end_node.get("id") - if end_node_id: - self.graph.add_edge(end_node_id, END) - logger.debug(f"Added edge: {end_node_id} -> END") + for node in self.reachable_nodes: + if not self._adj[node]: + self.graph.add_edge(node, END) return def build(self) -> CompiledStateGraph: diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py index e5a03c1c..be0c957a 100644 --- a/api/app/core/workflow/engine/result_builder.py +++ b/api/app/core/workflow/engine/result_builder.py @@ -2,6 +2,7 @@ # Author: Eternity # @Email: 1533512157@qq.com # @Time : 2026/2/10 13:33 +from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.engine.variable_pool import VariablePool @@ -9,6 +10,7 @@ class WorkflowResultBuilder: def build_final_output( self, result: dict, + execution_context: ExecutionContext, variable_pool: VariablePool, elapsed_time: float, final_output: str, @@ -26,6 +28,8 @@ class WorkflowResultBuilder: - "node_outputs" (dict): Outputs of executed nodes. - "messages" (list): Conversation messages exchanged during execution. - "error" (str, optional): Error message if any node failed. + execution_context (ExecutionContext): The execution context containing metadata like + execution ID, workspace ID, and user ID.) variable_pool (VariablePool): Variable Pool elapsed_time (float): Total execution time in seconds. final_output (Any): The aggregated or final output content of the workflow @@ -48,18 +52,23 @@ class WorkflowResultBuilder: """ node_outputs = result.get("node_outputs", {}) token_usage = self.aggregate_token_usage(node_outputs) - conversation_id = variable_pool.get_value("sys.conversation_id") + conversation_vars = {} + sys_vars = {} + + if variable_pool: + conversation_vars = variable_pool.get_all_conversation_vars() + sys_vars = variable_pool.get_all_system_vars() return { "status": "completed" if success else "failed", "output": final_output, "variables": { - "conv": variable_pool.get_all_conversation_vars(), - "sys": variable_pool.get_all_system_vars() + "conv": conversation_vars, + "sys": sys_vars }, "node_outputs": node_outputs, "messages": result.get("messages", []), - "conversation_id": conversation_id, + "conversation_id": execution_context.conversation_id, "elapsed_time": elapsed_time, "token_usage": token_usage, "error": result.get("error"), diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py index 48eafaa9..036ce0e8 100644 --- a/api/app/core/workflow/engine/runtime_schema.py +++ b/api/app/core/workflow/engine/runtime_schema.py @@ -12,6 +12,7 @@ class ExecutionContext(BaseModel): execution_id: str workspace_id: str user_id: str + conversation_id: str memory_storage_type: str user_rag_memory_id: str checkpoint_config: RunnableConfig @@ -22,6 +23,7 @@ class ExecutionContext(BaseModel): execution_id: str, workspace_id: str, user_id: str, + conversation_id: str, memory_storage_type: str, user_rag_memory_id: str ): @@ -29,6 +31,7 @@ class ExecutionContext(BaseModel): execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + conversation_id=conversation_id, memory_storage_type=memory_storage_type, user_rag_memory_id=user_rag_memory_id, diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 6a127e96..1170d66c 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -3,6 +3,7 @@ # @Email: 1533512157@qq.com # @Time : 2026/2/9 13:51 import datetime +import time import logging from typing import Any @@ -82,6 +83,7 @@ class WorkflowExecutor: CompiledStateGraph: The compiled and ready-to-run state graph. """ logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}") + start_time = time.time() builder = GraphBuilder( self.workflow_config, stream=stream, @@ -96,7 +98,8 @@ class WorkflowExecutor: variable_pool=self.variable_pool, execution_id=self.execution_context.execution_id ) - logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}") + logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, " + f"cost: {time.time() - start_time:.4f}s") return self.graph @@ -134,94 +137,12 @@ class WorkflowExecutor: return event.get("data") return self.result_builder.build_final_output( {"error": "Workflow execution did not end as expected"}, + self.execution_context, self.variable_pool, (datetime.datetime.now() - start).total_seconds(), "", success=False ) - # logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}") - # - # start_time = datetime.datetime.now() - # - # # Execute the workflow - # try: - # # Build the workflow graph - # graph = self.build_graph() - # - # # Initialize the variable pool with input data - # await self.variable_initializer.initialize( - # variable_pool=self.variable_pool, - # input_data=input_data, - # execution_context=self.execution_context - # ) - # initial_state = self.state_manager.create_initial_state( - # workflow_config=self.workflow_config, - # input_data=input_data, - # execution_context=self.execution_context, - # start_node_id=self.start_node_id - # ) - # - # result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) - # - # # Aggregate output from all End nodes - # full_content = '' - # for end_id in self.stream_coordinator.end_outputs.keys(): - # full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) - # - # # Append messages for user and assistant - # if input_data.get("files"): - # result["messages"].extend( - # [ - # { - # "role": "user", - # "content": input_data.get("message", '') - # }, - # { - # "role": "user", - # "content": input_data.get("files") - # }, - # { - # "role": "assistant", - # "content": full_content - # } - # ] - # ) - # else: - # result["messages"].extend( - # [ - # { - # "role": "user", - # "content": input_data.get("message", '') - # }, - # { - # "role": "assistant", - # "content": full_content - # } - # ] - # ) - # # Calculate elapsed time - # end_time = datetime.datetime.now() - # elapsed_time = (end_time - start_time).total_seconds() - # - # logger.info( - # f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") - # - # return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) - # - # except Exception as e: - # end_time = datetime.datetime.now() - # elapsed_time = (end_time - start_time).total_seconds() - # - # logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", - # exc_info=True) - # return { - # "status": "failed", - # "error": str(e), - # "output": None, - # "node_outputs": {}, - # "elapsed_time": elapsed_time, - # "token_usage": None - # } async def execute_stream( self, @@ -255,7 +176,7 @@ class WorkflowExecutor: "data": { "execution_id": self.execution_context.execution_id, "workspace_id": self.execution_context.workspace_id, - "conversation_id": input_data.get("conversation_id"), + "conversation_id": self.execution_context.conversation_id, "timestamp": int(start_time.timestamp() * 1000) } } @@ -376,6 +297,7 @@ class WorkflowExecutor: "event": "workflow_end", "data": self.result_builder.build_final_output( result, + self.execution_context, self.variable_pool, elapsed_time, full_content, @@ -396,6 +318,7 @@ class WorkflowExecutor: "event": "workflow_end", "data": self.result_builder.build_final_output( result, + self.execution_context, self.variable_pool, elapsed_time, full_content, @@ -432,6 +355,7 @@ async def execute_workflow( execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + conversation_id=input_data.get("conversation_id"), memory_storage_type=memory_storage_type, user_rag_memory_id=user_rag_memory_id ) @@ -471,6 +395,7 @@ async def execute_workflow_stream( workspace_id=workspace_id, user_id=user_id, memory_storage_type=memory_storage_type, + conversation_id=input_data.get("conversation_id"), user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 8959e27c..7b146a9c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -64,9 +64,7 @@ class AgentNode(BaseNode): if not release: raise ValueError(f"Agent 不存在: {agent_id}") - - return release, message async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 7f2b8aa6..34b7dfa3 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -315,8 +315,8 @@ class BaseNode(ABC): elapsed_time = (time.time() - start_time) * 1000 - logger.info(f"Node {self.node_id} streaming execution finished, " - f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") + logger.debug(f"Node {self.node_id} streaming execution finished, " + f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) @@ -644,7 +644,7 @@ class BaseNode(ABC): if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] with get_db_read() as db: - multimodel_service = MultimodalService(db, api_config=api_config) + multimodal_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( type=content.type, url=content.url, @@ -653,7 +653,7 @@ class BaseNode(ABC): upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, ) file_obj.set_content(content.get_content()) - message = await multimodel_service.process_files( + message = await multimodal_service.process_files( [file_obj], ) content.set_content(file_obj.get_content()) @@ -661,7 +661,7 @@ class BaseNode(ABC): content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message return message return None - raise TypeError(f'Unexpect input value type - {type(content)}') + raise TypeError(f'Unexpected input value type - {type(content)}') @staticmethod def process_model_output(content) -> str: diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index 5c2a6c2a..02df5091 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -1,9 +1,7 @@ """End 节点配置""" - from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition -from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig class EndNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 2799316a..770cf328 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -36,8 +36,6 @@ class EndNode(BaseNode): Returns: 最终输出字符串 """ - logger.info(f"节点 {self.node_id} (End) 开始执行") - # 获取配置的输出模板 output_template = self.config.get("output") @@ -46,11 +44,4 @@ class EndNode(BaseNode): output = self._render_template(output_template, variable_pool, strict=False) else: output = "" - - # 统计信息(用于日志) - node_outputs = state.get("node_outputs", {}) - total_nodes = len(node_outputs) - - logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - return output diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 43ab593b..5a603ac9 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -28,7 +28,7 @@ class NodeType(StrEnum): NOTES = "notes" -BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] +BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index fe38fafb..e1b84f0c 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel): ) -class HttpErrorDefaultTamplete(BaseModel): +class HttpErrorDefaultTemplate(BaseModel): body: str = Field( default="", description="Default body returned on HTTP error", @@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel): description="Error handling strategy: 'none', 'default', or 'branch'", ) - default: HttpErrorDefaultTamplete | None = Field( + default: HttpErrorDefaultTemplate | None = Field( default=None, description="Default response template for error handling", ) diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 23378c83..8aa8726e 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput -from app.core.workflow.utils.file_processer import mime_to_file_type +from app.core.workflow.utils.file_processor import mime_to_file_type from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.schemas import FileType, TransferMethod diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index a9618f7b..58567e6a 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -62,7 +62,6 @@ class StartNode(BaseNode): 包含系统参数、会话变量和自定义变量的字典 """ self.typed_config = StartNodeConfig(**self.config) - logger.info(f"节点 {self.node_id} (Start) 开始执行") # 处理自定义变量(传入 pool 避免重复创建) custom_vars = self._process_custom_variables(variable_pool) @@ -77,9 +76,9 @@ class StartNode(BaseNode): **custom_vars # 自定义变量作为节点输出的一部分 } - logger.info( - f"节点 {self.node_id} (Start) 执行完成," - f"输出了 {len(custom_vars)} 个自定义变量" + logger.debug( + f"Node {self.node_id} (Start) execution completed, " + f"outputting {len(custom_vars)} custom variables" ) return result diff --git a/api/app/core/workflow/utils/file_processer.py b/api/app/core/workflow/utils/file_processor.py similarity index 100% rename from api/app/core/workflow/utils/file_processer.py rename to api/app/core/workflow/utils/file_processor.py diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index fe4aea19..683ccb98 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -6,6 +6,7 @@ import copy import logging +from collections import defaultdict, deque from typing import Any, Union, TYPE_CHECKING from app.core.workflow.nodes.enums import NodeType @@ -119,7 +120,6 @@ class WorkflowValidator: errors = [] graphs = cls.get_subgraph(workflow_config) - logger.info(graphs) for index, graph in enumerate(graphs): nodes = graph.get("nodes", []) edges = graph.get("edges", []) @@ -204,18 +204,18 @@ class WorkflowValidator: Returns: 可达节点 ID 集合 """ + adj = defaultdict(list) + for edge in edges: + adj[edge["source"]].append(edge["target"]) + reachable = {start_id} - queue = [start_id] - + queue = deque([start_id]) while queue: - current = queue.pop(0) - for edge in edges: - if edge.get("source") == current: - target = edge.get("target") - if target and target not in reachable: - reachable.add(target) - queue.append(target) - + current = queue.popleft() + for target in adj[current]: + if target not in reachable: + reachable.add(target) + queue.append(target) return reachable @staticmethod diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 5e8e3f1e..79e023c1 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -54,7 +54,7 @@ class DictVariable(BaseVariable): def valid_value(self, value) -> dict: if not isinstance(value, dict): - raise TypeError(f"Value must be a dict - {type(value)}:{value}") + raise TypeError(f"Value must be a dict - {type(value)}:{value}") return value def to_literal(self) -> str: diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py index 2b36c5ea..fd8f25f3 100644 --- a/api/app/services/workflow_import_service.py +++ b/api/app/services/workflow_import_service.py @@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get from app.core.config import settings from app.core.exceptions import BusinessException from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult -from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration +from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.schemas import AppCreate from app.schemas.workflow_schema import WorkflowConfigCreate @@ -46,7 +46,7 @@ class WorkflowImportService: success=False, temp_id=None, workflow_id=None, - errors=[UnsupportPlatform(platform=platform)] + errors=[UnsupportedPlatform(platform=platform)] ) adapter = self.registry.get_adapter(platform, config) From 45eef128427cca589830f9a0290699b7f5ac7c5f Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 25 Mar 2026 14:11:55 +0800 Subject: [PATCH 015/120] perf(workflow): Adjust graph construction timing, adopting a lazy strategy for constructing cyclic subgraphs within nodes --- api/app/core/workflow/engine/graph_builder.py | 41 +++++++++++-------- api/app/core/workflow/executor.py | 3 +- .../core/workflow/nodes/cycle_graph/node.py | 12 +++--- api/app/core/workflow/validator.py | 10 +---- 4 files changed, 31 insertions(+), 35 deletions(-) diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 29f46765..d092db5b 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -61,21 +61,11 @@ class GraphBuilder: else: self.variable_pool = VariablePool() - self.graph = StateGraph(WorkflowState) - self.add_nodes() - self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) - self.end_nodes = [ - node - for node in self.nodes - if node.get("type") == "end" and node.get("id") in self.reachable_nodes - ] - self._reverse_adj: dict[str, list[dict]] = defaultdict(list) - self._adj: dict[str, list[str]] = defaultdict(list) - self._build_reverse_adj() - self.add_edges() - # EDGES MUST BE ADDED AFTER NODES ARE ADDED. - - self._analyze_end_node_output() + self.graph: StateGraph | None = None + self.reachable_nodes: set[str] | None = None + self.end_nodes: list[dict] = [] + self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list) + self._adj: dict[str, list[str]] | None = defaultdict(list) @property def nodes(self) -> list[dict[str, Any]]: @@ -109,7 +99,7 @@ class GraphBuilder: result[node[0]].append(node[1]) return result - def _build_reverse_adj(self): + def _build_adj(self): for edge in self.edges: if edge["source"] not in self.reachable_nodes: continue @@ -513,6 +503,21 @@ class GraphBuilder: return def build(self) -> CompiledStateGraph: + self.graph = StateGraph(WorkflowState) + self.add_nodes() + self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) + self.end_nodes = [ + node + for node in self.nodes + if node.get("type") == "end" and node.get("id") in self.reachable_nodes + ] + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) + self._adj: dict[str, list[str]] = defaultdict(list) + self._build_adj() + self.add_edges() + # EDGES MUST BE ADDED AFTER NODES ARE ADDED. + + self._analyze_end_node_output() checkpointer = InMemorySaver() - self.graph = self.graph.compile(checkpointer=checkpointer) - return self.graph + return self.graph.compile(checkpointer=checkpointer) + diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 1170d66c..0a820826 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -88,9 +88,10 @@ class WorkflowExecutor: self.workflow_config, stream=stream, ) + + self.graph = builder.build() self.start_node_id = builder.start_node_id self.variable_pool = builder.variable_pool - self.graph = builder.build() self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.event_handler = EventStreamHandler( diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 71e0dbdb..16939bac 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -32,15 +32,11 @@ class CycleGraphNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - - self.cycle_nodes = list() # Nodes belonging to this cycle - self.cycle_edges = list() # Edges connecting nodes within the cycle + self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() self.start_node_id = None # ID of the start node within the cycle self.graph: StateGraph | CompiledStateGraph | None = None self.child_variable_pool: VariablePool | None = None - self.build_graph() - self.iteration_flag = True def _output_types(self) -> dict[str, VariableType]: outputs = {"__child_state": VariableType.ARRAY_OBJECT} @@ -137,7 +133,7 @@ class CycleGraphNode(BaseNode): 3. Compile the graph for runtime execution """ from app.core.workflow.engine.graph_builder import GraphBuilder - self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() + self.child_variable_pool = VariablePool() builder = GraphBuilder( { @@ -147,8 +143,8 @@ class CycleGraphNode(BaseNode): subgraph=True, variable_pool=self.child_variable_pool ) - self.start_node_id = builder.start_node_id self.graph = builder.build() + self.start_node_id = builder.start_node_id self.child_variable_pool = builder.variable_pool async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: @@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode): Raises: RuntimeError: If the node type is unsupported. """ + self.build_graph() if self.node_type == NodeType.LOOP: return await LoopRuntime( start_id=self.start_node_id, @@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode): raise RuntimeError("Unknown cycle node type") async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): + self.build_graph() if self.node_type == NodeType.LOOP: yield { "__final__": True, diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 683ccb98..0ad74865 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -183,7 +183,7 @@ class WorkflowValidator: has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) if has_cycle: errors.append( - f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" + f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}" ) # 8. 验证变量名 @@ -229,10 +229,6 @@ class WorkflowValidator: Returns: (has_cycle, cycle_path): 是否有循环和循环路径 """ - # 排除 loop 类型的节点 - loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"} - - # 构建邻接表(排除 loop 节点的边和错误边) graph: dict[str, list[str]] = {} for edge in edges: source = edge.get("source") @@ -243,10 +239,6 @@ class WorkflowValidator: if edge_type == "error": continue - # 如果涉及 loop 节点,跳过 - if source in loop_nodes or target in loop_nodes: - continue - if source and target: if source not in graph: graph[source] = [] From 85daf576e980d992884756e8768c4c6ad2b76a0b Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 25 Mar 2026 17:03:12 +0800 Subject: [PATCH 016/120] perf(workflow): Optimize downstream node activation method to reduce performance overhead --- .../extraction_orchestrator.py | 1 - api/app/core/workflow/engine/graph_builder.py | 78 ++++++++++--------- api/app/core/workflow/nodes/assigner/node.py | 4 +- api/app/core/workflow/nodes/base_node.py | 70 ++++++++--------- api/app/core/workflow/nodes/code/node.py | 4 +- .../core/workflow/nodes/cycle_graph/node.py | 18 ++--- .../core/workflow/nodes/http_request/node.py | 4 +- api/app/core/workflow/nodes/if_else/node.py | 4 +- .../core/workflow/nodes/jinja_render/node.py | 4 +- api/app/core/workflow/nodes/knowledge/node.py | 4 +- api/app/core/workflow/nodes/llm/node.py | 4 +- api/app/core/workflow/nodes/memory/node.py | 8 +- api/app/core/workflow/nodes/node_factory.py | 6 +- .../nodes/parameter_extractor/node.py | 4 +- .../nodes/question_classifier/node.py | 4 +- api/app/core/workflow/nodes/start/node.py | 10 +-- api/app/core/workflow/nodes/tool/node.py | 4 +- .../nodes/variable_aggregator/node.py | 4 +- .../core/workflow/utils/template_renderer.py | 7 +- 19 files changed, 122 insertions(+), 120 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index da10c497..e0b86d8c 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1099,7 +1099,6 @@ class ExtractionOrchestrator: metadata=chunk.metadata, ) chunk_nodes.append(chunk_node) - logger.error(f"chunk file: {chunk.files}") for p, file_type in chunk.files: diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index d092db5b..daef6e82 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -7,7 +7,7 @@ import re import uuid from collections import defaultdict from functools import lru_cache -from typing import Any, Iterable +from typing import Any, Iterable, Callable from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import START, END @@ -41,39 +41,31 @@ class GraphBuilder: self, workflow_config: dict[str, Any], stream: bool = False, - subgraph: bool = False, + cycle: str = '', variable_pool: VariablePool | None = None ): self.workflow_config = workflow_config self.stream = stream - self.subgraph = subgraph + self.cycle = cycle self.start_node_id: str | None = None - self.node_map = {node["id"]: node for node in self.nodes} + self.node_map: dict[str, dict] = {} self.end_node_map: dict[str, StreamOutputConfig] = {} - self._find_upstream_activation_dep = lru_cache( - maxsize=len(self.nodes) * 2 - )(self._find_upstream_activation_dep) + self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep if variable_pool: self.variable_pool = variable_pool else: self.variable_pool = VariablePool() self.graph: StateGraph | None = None + self.nodes: list = [] + self.edges: list = [] self.reachable_nodes: set[str] | None = None self.end_nodes: list[dict] = [] - self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list) - self._adj: dict[str, list[str]] | None = defaultdict(list) - - @property - def nodes(self) -> list[dict[str, Any]]: - return self.workflow_config.get("nodes", []) - - @property - def edges(self) -> list[dict[str, Any]]: - return self.workflow_config.get("edges", []) + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) + self._adj: dict[str, list[str]] = defaultdict(list) def get_node_type(self, node_id: str) -> str: """Retrieve the type of node given its ID. @@ -294,22 +286,13 @@ class GraphBuilder: """ for node in self.nodes: node_type = node.get("type") - if node_type == NodeType.NOTES: - continue node_id = node.get("id") - cycle_node = node.get("cycle") - if cycle_node: - # Nodes within a loop subgraph are constructed by CycleGraphNode - if not self.subgraph: - continue - - # Record start and end node IDs - if node_type in [NodeType.START, NodeType.CYCLE_START]: - self.start_node_id = node_id + if node_id not in self.reachable_nodes: + continue # Create node instance (start and end nodes are also created) # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph - node_instance = NodeFactory.create_node(node, self.workflow_config) + node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id]) if node_type in BRANCH_NODES: @@ -503,21 +486,46 @@ class GraphBuilder: return def build(self) -> CompiledStateGraph: - self.graph = StateGraph(WorkflowState) - self.add_nodes() + nodes = self.workflow_config.get("nodes", []) + edges = self.workflow_config.get("edges", []) + + for node in nodes: + if (node.get("cycle") or '') == self.cycle: + node_type = node.get("type") + if node_type in [NodeType.START, NodeType.CYCLE_START]: + self.start_node_id = node.get("id") + elif node_type == NodeType.NOTES: + continue + self.nodes.append(node) + self.node_map[node.get("id")] = node + + for edge in edges: + source_in = edge.get("source") in self.node_map + target_in = edge.get("target") in self.node_map + if source_in ^ target_in: + raise ValueError( + f"Cycle node is connected to external node, " + f"source: {edge.get('source')}, target: {edge.get('target')}" + ) + + if source_in and target_in: + self.edges.append(edge) + self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) self.end_nodes = [ node for node in self.nodes if node.get("type") == "end" and node.get("id") in self.reachable_nodes ] - self._reverse_adj: dict[str, list[dict]] = defaultdict(list) - self._adj: dict[str, list[str]] = defaultdict(list) self._build_adj() + self._find_upstream_activation_dep: Callable = lru_cache( + maxsize=len(self.nodes)*2 + )(self._find_upstream_activation_dep) + + self.graph = StateGraph(WorkflowState) + self.add_nodes() self.add_edges() - # EDGES MUST BE ADDED AFTER NODES ARE ADDED. self._analyze_end_node_output() checkpointer = InMemorySaver() return self.graph.compile(checkpointer=checkpointer) - diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 4c897d5a..f5bdf000 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) class AssignerNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.variable_updater = True self.typed_config: AssignerNodeConfig | None = None diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 34b7dfa3..0b31c9e3 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -28,7 +28,7 @@ class BaseNode(ABC): All node types should inherit from this class and implement the `execute` method. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): """Initialize the node. Args: @@ -41,6 +41,7 @@ class BaseNode(ABC): self.node_type = node_config["type"] self.cycle = node_config.get("cycle") self.node_name = node_config.get("name", self.node_id) + self.down_stream_nodes = down_stream_nodes # 使用 or 运算符处理 None 值 self.config = node_config.get("config") or {} self.error_handling = node_config.get("error_handling") or {} @@ -93,18 +94,16 @@ class BaseNode(ABC): dict: A dict with a single key 'activate', mapping node IDs to their activation status (True/False). """ - edges = self.workflow_config.get("edges") - under_stream_nodes = [ - edge.get("target") - for edge in edges - if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES - ] - return { - "activate": { - node_id: self.check_activate(state) - for node_id in under_stream_nodes - } | {self.node_id: self.check_activate(state)} - } + activate_flag = self.check_activate(state) + + if self.node_type not in BRANCH_NODES: + activate = {node_id: activate_flag for node_id in self.down_stream_nodes} + else: + activate = {} + + activate[self.node_id] = activate_flag + + return {"activate": activate} @abstractmethod async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: @@ -428,8 +427,8 @@ class BaseNode(ABC): when an error edge exists. If no error edge exists, this method raises an exception to stop the workflow. """ - # Check if the node has an error edge defined - error_edge = self._find_error_edge() + # # Check if the node has an error edge defined + # error_edge = self._find_error_edge() # Extract input data (for logging or audit purposes) input_data = self._extract_input(state, variable_pool) @@ -447,27 +446,26 @@ class BaseNode(ABC): "error": error_message } - if error_edge: - # If an error edge exists, log a warning and continue to error node - logger.warning( - f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" - ) - return { - "node_outputs": { - self.node_id: node_output - }, - "error": error_message, - "error_node": self.node_id - } - else: - # If no error edge, send the error via stream writer and stop the workflow - writer = get_stream_writer() - writer({ - "type": "node_error", - **node_output - }) - logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") - raise Exception(f"Node {self.node_id} execution failed: {error_message}") + # if error_edge: + # # If an error edge exists, log a warning and continue to error node + # logger.warning( + # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" + # ) + # return { + # "node_outputs": { + # self.node_id: node_output + # }, + # "error": error_message, + # "error_node": self.node_id + # } + # else: + writer = get_stream_writer() + writer({ + "type": "node_error", + **node_output + }) + logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") + raise Exception(f"Node {self.node_id} execution failed: {error_message}") def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """Extracts the input data for this node (used for logging or audit). diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 1e055002..d89b208b 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -51,8 +51,8 @@ console.log(result) class CodeNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: CodeNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 16939bac..fc80939f 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -30,8 +30,8 @@ class CycleGraphNode(BaseNode): It acts as a container and execution controller for a subgraph. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() self.start_node_id = None # ID of the start node within the cycle @@ -115,11 +115,11 @@ class CycleGraphNode(BaseNode): else: remain_edges.append(edge) - # Update workflow_config by removing cycle nodes and internal edges - self.workflow_config["nodes"] = [ - node for node in nodes if node.get("cycle") != self.node_id - ] - self.workflow_config["edges"] = remain_edges + # # Update workflow_config by removing cycle nodes and internal edges + # self.workflow_config["nodes"] = [ + # node for node in nodes if node.get("cycle") != self.node_id + # ] + # self.workflow_config["edges"] = remain_edges return cycle_nodes, cycle_edges @@ -140,8 +140,8 @@ class CycleGraphNode(BaseNode): "nodes": self.cycle_nodes, "edges": self.cycle_edges, }, - subgraph=True, - variable_pool=self.child_variable_pool + variable_pool=self.child_variable_pool, + cycle=self.node_id ) self.graph = builder.build() self.start_node_id = builder.start_node_id diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 8aa8726e..086bee4a 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode): or a branch identifier string when error branching is enabled. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: HttpRequestNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 5d2bdf9a..ec46b20b 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) class IfElseNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: IfElseNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index e13709d4..abf21524 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) class JinjaRenderNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: JinjaRenderNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index d3e9efd9..92699cb4 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -21,8 +21,8 @@ logger = logging.getLogger(__name__) class KnowledgeRetrievalNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.vector_service: ElasticSearchVector | None = None diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 66a0f1ac..a691001f 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -70,8 +70,8 @@ class LLMNode(BaseNode): - ai/assistant: AI 消息(AIMessage) """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: LLMNodeConfig | None = None self.messages = [] diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index a28247e4..73c52b79 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -14,8 +14,8 @@ from app.tasks import write_message_task class MemoryReadNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: MemoryReadNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: @@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode): class MemoryWriteNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: MemoryWriteNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 864e3251..9e5a7d24 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -104,13 +104,15 @@ class NodeFactory: def create_node( cls, node_config: dict[str, Any], - workflow_config: dict[str, Any] + workflow_config: dict[str, Any], + down_stream_nodes: list[str] ) -> WorkflowNode | None: """创建节点实例 Args: node_config: 节点配置 workflow_config: 工作流配置 + down_stream_nodes: 下游节点 Returns: 节点实例或 None(对于不支持的节点类型) @@ -127,7 +129,7 @@ class NodeFactory: # 创建节点实例 logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") - return node_class(node_config, workflow_config) + return node_class(node_config, workflow_config, down_stream_nodes) @classmethod def get_supported_types(cls) -> list[str]: diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index acac09e4..3dc5fcc3 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -21,8 +21,8 @@ logger = logging.getLogger(__name__) class ParameterExtractorNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: ParameterExtractorNodeConfig | None = None self.response_metadata = {} diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 5cebd886..31fadaf6 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1" class QuestionClassifierNode(BaseNode): """问题分类器节点""" - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} self.response_metadata = {} diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index 58567e6a..7a324cc4 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -27,14 +27,8 @@ class StartNode(BaseNode): 注意:变量的验证和默认值处理由 Executor 在初始化时完成。 """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - """初始化 Start 节点 - - Args: - node_config: 节点配置 - workflow_config: 工作流配置 - """ - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) # 解析并验证配置 self.typed_config: StartNodeConfig | None = None diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 0e9d3c62..72c5c6a8 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}") class ToolNode(BaseNode): """工具节点""" - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: ToolNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index de82f8ff..9a9c5566 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) class VariableAggregatorNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: VariableAggregatorNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 424fdf20..6a73efc4 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -153,7 +153,8 @@ class TemplateRenderer: # 全局渲染器实例(严格模式) -_default_renderer = TemplateRenderer(strict=True) +_strict_renderer = TemplateRenderer(strict=True) +_lenient_renderer = TemplateRenderer(strict=False) def render_template( @@ -184,7 +185,7 @@ def render_template( ... ) '请分析: 这是一段文本' """ - renderer = TemplateRenderer(strict=strict) + renderer = _strict_renderer if strict else _lenient_renderer return renderer.render(template, conv_vars, node_outputs, system_vars) @@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]: Returns: 错误列表 """ - return _default_renderer.validate(template) + return _strict_renderer.validate(template) From 1794f8f209f39ec1050a4c7ecc60fa906f96f20a Mon Sep 17 00:00:00 2001 From: wxy Date: Wed, 25 Mar 2026 17:14:11 +0800 Subject: [PATCH 017/120] feat: block deactivating user who is tenant contact --- api/app/models/tenant_model.py | 11 +++++++++++ api/app/services/user_service.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index 044857d2..a92b5629 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -23,6 +23,17 @@ class Tenants(Base): # 国际化语言配置字段 default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言 supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表 + + # 租户联系信息 + contact_name = Column(String(100), nullable=True) # 联系人姓名 + contact_email = Column(String(255), nullable=True) # 联系人邮箱 + contact_phone = Column(String(50), nullable=True) # 联系人电话 + + # 租户套餐信息 + plan = Column(String(50), nullable=True) # 套餐类型 + plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间 + api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制 + status = Column(String(50), nullable=True, default='active') # 租户状态 # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index b5522b74..3122d282 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -250,6 +250,20 @@ def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user: } ) + # 检查是否为租户联系人 + from app.models.tenant_model import Tenants + tenant = db.query(Tenants).filter(Tenants.id == db_user.tenant_id).first() + if tenant and tenant.contact_email and tenant.contact_email == db_user.email: + business_logger.warning(f"尝试停用租户联系人: {db_user.email}, tenant_id={db_user.tenant_id}") + raise BusinessException( + "该管理员是租户联系人,请先在租户信息中更换联系邮箱,再禁用此管理员", + code=BizCode.FORBIDDEN, + context={ + "user_id": str(user_id_to_deactivate), + "tenant_id": str(db_user.tenant_id) + } + ) + # 停用用户 business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})") db_user.is_active = False From caab58dd2fbe4ebb997b3c80613623664e1876f1 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 25 Mar 2026 17:54:27 +0800 Subject: [PATCH 018/120] fix(file and app): 1. Handle the encoding issue when downloading Markdown files; 2. Experience the sharing of memory configuration --- .../controllers/file_storage_controller.py | 10 +++++++--- .../controllers/public_share_controller.py | 1 + api/app/core/storage/base.py | 18 +++++++----------- api/app/core/storage/local.py | 8 +++++++- api/app/core/storage/oss.py | 16 +++++++++++++--- api/app/core/storage/s3.py | 19 +++++++++++++------ api/app/services/file_storage_service.py | 15 +++++++++------ 7 files changed, 57 insertions(+), 30 deletions(-) diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index 14962a72..4e1ba74c 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -574,8 +574,12 @@ async def get_file_url( # For local storage, generate signed URL with expiration url = generate_signed_url(str(file_id), expires) else: - # For remote storage (OSS/S3), get presigned URL - url = await storage_service.get_file_url(file_key, expires=expires) + # For remote storage (OSS/S3), get presigned URL with forced download + url = await storage_service.get_file_url( + file_key, + expires=expires, + file_name=file_metadata.file_name, + ) url = _match_scheme(request, url) api_logger.info(f"Generated file URL: file_id={file_id}") @@ -786,7 +790,7 @@ async def permanent_download_file( # For remote storage, redirect to presigned URL with long expiration try: # Use a very long expiration (7 days max for most cloud providers) - presigned_url = await storage_service.get_file_url(file_key, expires=604800) + presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name) presigned_url = _match_scheme(request, presigned_url) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) except Exception as e: diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 33d7b60c..f5284b46 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -669,6 +669,7 @@ async def config_query( content = { "app_type": release.app.type, "variables": release.config.get("variables"), + "memory": release.config.get("memory", {}).get("enabled"), "features": release.config.get("features") } elif release.app.type == AppType.MULTI_AGENT: diff --git a/api/app/core/storage/base.py b/api/app/core/storage/base.py index 8ab0fcde..09824c3f 100644 --- a/api/app/core/storage/base.py +++ b/api/app/core/storage/base.py @@ -109,17 +109,13 @@ class StorageBackend(ABC): pass @abstractmethod - async def get_url(self, file_key: str, expires: int = 3600) -> str: - """ - Get an access URL for the file. - - Args: - file_key: Unique identifier for the file in the storage system. - expires: URL validity period in seconds (default: 1 hour). - - Returns: - URL for accessing the file. - """ + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None + ) -> str: + """Get an access URL for the file.""" pass async def get_permanent_url(self, file_key: str) -> Optional[str]: diff --git a/api/app/core/storage/local.py b/api/app/core/storage/local.py index 4b8ae829..13adfc20 100644 --- a/api/app/core/storage/local.py +++ b/api/app/core/storage/local.py @@ -210,7 +210,12 @@ class LocalStorage(StorageBackend): cause=e, ) - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None + ) -> str: """ Get an access URL for the file. @@ -220,6 +225,7 @@ class LocalStorage(StorageBackend): Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (not used for local storage). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A relative URL path for accessing the file. diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py index 27669ffa..1db86fef 100644 --- a/api/app/core/storage/oss.py +++ b/api/app/core/storage/oss.py @@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK. import io import logging +import urllib.parse from typing import AsyncIterator, Optional import oss2 @@ -242,24 +243,33 @@ class OSSStorage(StorageBackend): logger.error(f"Failed to check file existence in OSS {file_key}: {e}") return False - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get a presigned URL for accessing the file. Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A presigned URL for accessing the file. """ try: - url = self.bucket.sign_url("GET", file_key, expires) + params = {} + if file_name: + filename_encoded = urllib.parse.quote(file_name.encode("utf-8")) + params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}" + url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None) logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") return url except Exception as e: logger.error(f"Failed to generate presigned URL for {file_key}: {e}") - # Return a basic URL format as fallback return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" async def get_permanent_url(self, file_key: str) -> str: diff --git a/api/app/core/storage/s3.py b/api/app/core/storage/s3.py index c7b33ffe..f156f4a7 100644 --- a/api/app/core/storage/s3.py +++ b/api/app/core/storage/s3.py @@ -6,6 +6,7 @@ using the boto3 SDK. """ import io +import urllib.parse import logging from typing import AsyncIterator, Optional @@ -352,31 +353,37 @@ class S3Storage(StorageBackend): logger.error(f"Failed to check file existence in S3 {file_key}: {e}") return False - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get a presigned URL for accessing the file. Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A presigned URL for accessing the file. """ try: + params = {"Bucket": self.bucket_name, "Key": file_key} + if file_name: + filename_encoded = urllib.parse.quote(file_name.encode("utf-8")) + params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}" url = self.client.generate_presigned_url( "get_object", - Params={ - "Bucket": self.bucket_name, - "Key": file_key, - }, + Params=params, ExpiresIn=expires, ) logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") return url except Exception as e: logger.error(f"Failed to generate presigned URL for {file_key}: {e}") - # Return a basic URL format as fallback return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" async def get_permanent_url(self, file_key: str) -> str: diff --git a/api/app/services/file_storage_service.py b/api/app/services/file_storage_service.py index 2ebc5d9a..5897936b 100644 --- a/api/app/services/file_storage_service.py +++ b/api/app/services/file_storage_service.py @@ -325,27 +325,30 @@ class FileStorageService: ) raise - async def get_file_url(self, file_key: str, expires: int = 3600) -> str: + async def get_file_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get an access URL for a file. Args: file_key: The file key. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: URL for accessing the file. """ logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s") - try: - url = await self.storage.get_url(file_key, expires) + url = await self.storage.get_url(file_key, expires, file_name=file_name) logger.debug(f"File URL generated: file_key={file_key}") return url except Exception as e: - logger.error( - f"Error getting file URL: file_key={file_key}, error={str(e)}" - ) + logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}") raise From 14413fd4132f8261e375de41f024465ccd846de6 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 25 Mar 2026 17:54:52 +0800 Subject: [PATCH 019/120] [changes] Statistical analysis of shared and non-shared applications in the RAG storage mode --- api/app/controllers/memory_dashboard_controller.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index cc0efab3..fe4337d1 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -663,9 +663,12 @@ async def dashboard_data( rag_data["total_memory"] = total_chunk # total_app: 统计当前空间下的所有app数量 - from app.repositories import app_repository - apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - rag_data["total_app"] = len(apps_orm) + # 包含自有app + 被分享给本工作空间的app + from app.services import app_service as _app_svc + _, total_app = _app_svc.AppService(db).list_apps( + workspace_id=workspace_id, include_shared=True, pagesize=1 + ) + rag_data["total_app"] = total_app # total_knowledge: 使用 total_kb(总知识库数) total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) @@ -687,7 +690,7 @@ async def dashboard_data( api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}") rag_data["total_api_call"] = 0 - api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") + api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") except Exception as e: api_logger.warning(f"获取RAG相关数据失败: {str(e)}") From 294ee49d599a55d62542ad439f3dd2fcfc9c0404 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 25 Mar 2026 18:06:01 +0800 Subject: [PATCH 020/120] fix(file and app): embedding and volcano model --- api/app/core/models/embedding.py | 2 +- api/app/core/models/generation.py | 1 - api/app/services/generation_service.py | 6 ++---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/api/app/core/models/embedding.py b/api/app/core/models/embedding.py index 9ccf53de..3269e1d0 100644 --- a/api/app/core/models/embedding.py +++ b/api/app/core/models/embedding.py @@ -90,7 +90,7 @@ class RedBearEmbeddings(Embeddings): input=contents, **kwargs ) - return [item.embedding for item in response.data] + return [response.data.embedding] async def aembed_multimodal( self, diff --git a/api/app/core/models/generation.py b/api/app/core/models/generation.py index 98b23fbf..b6388d3f 100644 --- a/api/app/core/models/generation.py +++ b/api/app/core/models/generation.py @@ -64,7 +64,6 @@ class RedBearImageGenerator: prompt: 提示词 image: 参考图片URL或URL列表(图文生图/多图融合) size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080" 等(至少3686400像素) - n: 生成数量 output_format: 输出格式,如 "png", "jpg" response_format: 返回格式,"url" 或 "b64_json" watermark: 是否添加水印 diff --git a/api/app/services/generation_service.py b/api/app/services/generation_service.py index e7800ef6..2505793c 100644 --- a/api/app/services/generation_service.py +++ b/api/app/services/generation_service.py @@ -25,8 +25,7 @@ class GenerationService: self, model_config_id: str, prompt: str, - size: Optional[str] = "1024x1024", - n: int = 1, + size: Optional[str] = "2k", **kwargs ) -> Dict[str, Any]: """ @@ -36,7 +35,6 @@ class GenerationService: model_config_id: 模型配置ID prompt: 提示词 size: 图片尺寸 - n: 生成数量 **kwargs: 其他参数 Returns: @@ -69,7 +67,7 @@ class GenerationService: # 生成图片 generator = RedBearImageGenerator(config) - result = await generator.agenerate(prompt, size, n, **kwargs) + result = await generator.agenerate(prompt, size, **kwargs) return result From 2df615eca0771c0c80991042e2dde57b8eb511b8 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 25 Mar 2026 18:46:43 +0800 Subject: [PATCH 021/120] fix(mcp market): Handling 401 error --- .../controllers/mcp_market_config_controller.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py index 0f2da3b0..6f27d87a 100644 --- a/api/app/controllers/mcp_market_config_controller.py +++ b/api/app/controllers/mcp_market_config_controller.py @@ -91,9 +91,11 @@ async def get_mcp_servers( try: cookies = api.get_cookies(token) + headers=api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' r = api.session.put( url=api.mcp_base_url, - headers=api.builder_headers(api.headers), + headers=headers, json=body, cookies=cookies) raise_for_http_status(r) @@ -173,6 +175,7 @@ async def get_operational_mcp_servers( url = f'{api.mcp_base_url}/operational' headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' try: cookies = api.get_cookies(access_token=token, cookies_required=True) @@ -260,7 +263,9 @@ async def create_mcp_market_config( api.login(create_data.token) body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} cookies = api.get_cookies(create_data.token) - r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {create_data.token}' + r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies) raise_for_http_status(r) except Exception as e: api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") @@ -290,9 +295,11 @@ async def create_mcp_market_config( 'search': "" } cookies = api.get_cookies(token) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' r = api.session.put( url=api.mcp_base_url, - headers=api.builder_headers(api.headers), + headers=headers, json=body, cookies=cookies) raise_for_http_status(r) @@ -393,7 +400,9 @@ async def update_mcp_market_config( api.login(update_data.token) body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} cookies = api.get_cookies(update_data.token) - r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {update_data.token}' + r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies) raise_for_http_status(r) except Exception as e: api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") From c4461c4917ddfc66e2325994100ac0ef9ee60b4a Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 24 Mar 2026 12:27:13 +0800 Subject: [PATCH 022/120] =?UTF-8?q?=E3=80=90add=E3=80=91User=20alias=20ext?= =?UTF-8?q?raction=20and=20retrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../prompt/Problem_Extension_prompt.jinja2 | 52 ++++++++++++++++++ .../prompt/problem_breakdown_prompt.jinja2 | 52 ++++++++++++++++++ .../deduplication/deduped_and_disamb.py | 43 +++++++++++++++ .../prompt/prompts/extract_triplet.jinja2 | 18 ++++++ api/app/repositories/neo4j/cypher_queries.py | 55 +++++++++++++++++++ api/app/repositories/neo4j/graph_search.py | 3 +- api/app/schemas/end_user_schema.py | 3 +- 7 files changed, 224 insertions(+), 2 deletions(-) diff --git a/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 index a0e21fbd..c78cbaac 100644 --- a/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 @@ -39,6 +39,30 @@ 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? +## 指代消歧规则(Coreference Resolution): +在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化: + +1. **"用户"的消歧**: + - "用户是谁?" → 分析历史记录,找出对话发起者的姓名 + - 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人 + - 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?" + +2. **"我"的消歧**: + - "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?" + - 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?" + +3. **"他/她/它"的消歧**: + - 从上下文或历史中找出最近提到的同类实体 + - 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?" + +4. **"那个人/这个人"的消歧**: + - 从历史中找出最近提到的人物 + - 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?" + +5. **优先级**: + - 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人 + - 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象" + 输出要求: @@ -71,6 +95,34 @@ "reason": "输出原问题的关键要素" } ] + +## 指代消歧示例(重要): +示例1 - "用户"的消歧: +输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}] +输入问题:"用户是谁?" +输出: +[ + { + "original_question": "用户是谁?", + "extended_question": "李建国是谁?", + "type": "单跳", + "reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国" + } +] + +示例2 - "我"的消歧: +输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}] +输入问题:"我推荐的书是什么?" +输出: +[ + { + "original_question": "我推荐的书是什么?", + "extended_question": "张曼玉推荐的书是什么?", + "type": "单跳", + "reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉" + } +] + **Output format** **CRITICAL JSON FORMATTING REQUIREMENTS:** 1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes diff --git a/api/app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 index aca716a4..ff134ddb 100644 --- a/api/app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 @@ -27,6 +27,30 @@ 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? +## 指代消歧规则(Coreference Resolution): +在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化: + +1. **"用户"的消歧**: + - "用户是谁?" → 分析历史记录,找出对话发起者的姓名 + - 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人 + - 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?" + +2. **"我"的消歧**: + - "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?" + - 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?" + +3. **"他/她/它"的消歧**: + - 从上下文或历史中找出最近提到的同类实体 + - 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?" + +4. **"那个人/这个人"的消歧**: + - 从历史中找出最近提到的人物 + - 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?" + +5. **优先级**: + - 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人 + - 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象" + ## 指令: 你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型: 单跳(Single-hop) @@ -151,6 +175,34 @@ ] - 必须通过json.loads()的格式支持的形式输出 - 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 + +## 指代消歧示例(重要): +示例1 - "用户"的消歧: +输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}] +输入问题:"用户是谁?" +输出: +[ + { + "id": "Q1", + "question": "李建国是谁?", + "type": "单跳", + "reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国" + } +] + +示例2 - "我"的消歧: +输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}] +输入问题:"我推荐的书是什么?" +输出: +[ + { + "id": "Q1", + "question": "张曼玉推荐的书是什么?", + "type": "单跳", + "reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉" + } +] + - 关键的JSON格式要求 1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号 2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们 diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index f2f14d9e..622f6e05 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -203,6 +203,7 @@ def accurate_match( ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: """ 精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。 + 同时检测某实体的 name 是否命中另一实体的 aliases,若命中则直接合并。 返回: (deduped_entities, id_redirect, exact_merge_map) """ exact_merge_map: Dict[str, Dict] = {} @@ -240,6 +241,48 @@ def accurate_match( pass deduped_entities = list(canonical_map.values()) + + # 2) 第二轮:检测某实体的 name 是否命中另一实体的 aliases(alias-to-name 精确合并) + # 场景:LLM 把 aliases 中的词(如"齐齐")又单独抽取为独立实体,需在此阶段合并掉 + # 优化:先构建 (end_user_id, alias_lower) -> canonical 的反向索引,查找 O(1) + alias_index: Dict[tuple, ExtractedEntityNode] = {} + for canonical in deduped_entities: + uid = getattr(canonical, "end_user_id", None) + for alias in (getattr(canonical, "aliases", []) or []): + alias_lower = alias.strip().lower() + if alias_lower: + alias_index[(uid, alias_lower)] = canonical + + i = 0 + while i < len(deduped_entities): + ent = deduped_entities[i] + ent_name = (getattr(ent, "name", "") or "").strip().lower() + ent_uid = getattr(ent, "end_user_id", None) + canonical = alias_index.get((ent_uid, ent_name)) + # 确保不是自身 + if canonical is not None and canonical.id != ent.id: + _merge_attribute(canonical, ent) + id_redirect[ent.id] = canonical.id + for k, v in list(id_redirect.items()): + if v == ent.id: + id_redirect[k] = canonical.id + try: + k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" + if k not in exact_merge_map: + exact_merge_map[k] = { + "canonical_id": canonical.id, + "end_user_id": canonical.end_user_id, + "name": canonical.name, + "entity_type": canonical.entity_type, + "merged_ids": set(), + } + exact_merge_map[k]["merged_ids"].add(ent.id) + except Exception: + pass + deduped_entities.pop(i) + else: + i += 1 + return deduped_entities, id_redirect, exact_merge_map def fuzzy_match( diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index b2f287f4..25fffa33 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -96,6 +96,15 @@ The following shows type inheritance relationships (Child → Parent → Grandpa {% endif %} * Include common alternative names, abbreviations and full names * If no aliases exist, use empty array: [] + + **姓名别名识别规则(Name Alias Recognition):** + * 当前对话的用户实体 name 固定为"用户",不得使用用户透露的真实姓名作为 name + * 自我称呼模式:用户说"我的名字是X"、"我叫X" → X 加入 aliases(name 保持为"用户") + * 昵称/小名模式:识别"小名"、"昵称"、"英文名"、"网名"等关键词后的称呼 → 加入 aliases + * 他人称呼模式:识别"同事叫我X"、"朋友叫我X"、"大家叫我X" → X 加入 aliases + * 同一实体的多个称呼应合并到同一 Entity 的 aliases 列表中 + * aliases 中不应包含与 name 完全相同的字符串 + * **严禁将已加入某实体 aliases 的词再单独抽取为另一个独立实体**:若某个词已作为别名归属于"用户"实体,则不得再将该词作为独立 Entity 的 name 出现在 entities 列表中 - Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions - For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value) Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric' @@ -207,6 +216,15 @@ Output: {"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false} ] } + +**Example 4 (姓名别名识别 - Chinese):** "我的名字是乐力齐,我的小名是齐齐,同事们都叫我小乐" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人,有多个称呼", "example": "", "aliases": ["乐力齐", "齐齐", "小乐"], "is_explicit_memory": false} + ] +} {% endif %} ===End of Examples=== diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 1f699ad8..f80b7e26 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -336,6 +336,61 @@ ORDER BY score DESC LIMIT $limit """ +SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """ +CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score +WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) +RETURN e.id AS id, + e.name AS name, + e.end_user_id AS end_user_id, + e.entity_type AS entity_type, + e.created_at AS created_at, + e.expired_at AS expired_at, + e.entity_idx AS entity_idx, + e.statement_id AS statement_id, + e.description AS description, + e.aliases AS aliases, + e.name_embedding AS name_embedding, + e.connect_strength AS connect_strength, + collect(DISTINCT s.id) AS statement_ids, + collect(DISTINCT c.id) AS chunk_ids, + COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, + COALESCE(e.importance_score, 0.5) AS importance_score, + e.last_access_time AS last_access_time, + COALESCE(e.access_count, 0) AS access_count, + score +UNION +MATCH (e:ExtractedEntity) +WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) + AND e.aliases IS NOT NULL + AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q)) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) +RETURN e.id AS id, + e.name AS name, + e.end_user_id AS end_user_id, + e.entity_type AS entity_type, + e.created_at AS created_at, + e.expired_at AS expired_at, + e.entity_idx AS entity_idx, + e.statement_id AS statement_id, + e.description AS description, + e.aliases AS aliases, + e.name_embedding AS name_embedding, + e.connect_strength AS connect_strength, + collect(DISTINCT s.id) AS statement_ids, + collect(DISTINCT c.id) AS chunk_ids, + COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, + COALESCE(e.importance_score, 0.5) AS importance_score, + e.last_access_time AS last_access_time, + COALESCE(e.access_count, 0) AS access_count, + 0.8 AS score +ORDER BY score DESC +LIMIT $limit +""" + + SEARCH_CHUNKS_BY_CONTENT = """ CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index d3aabd32..c5d3bcca 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -13,6 +13,7 @@ from app.repositories.neo4j.cypher_queries import ( SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_ENTITIES_BY_NAME, + SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_KEYWORD, @@ -264,7 +265,7 @@ async def search_graph( if "entities" in include: tasks.append(connector.execute_query( - SEARCH_ENTITIES_BY_NAME, + SEARCH_ENTITIES_BY_NAME_OR_ALIAS, q=q, end_user_id=end_user_id, limit=limit, diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index bbb6fd5c..09671b91 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -1,6 +1,6 @@ import uuid import datetime -from typing import Optional +from typing import Optional, List from pydantic import BaseModel, Field from pydantic import ConfigDict @@ -49,6 +49,7 @@ class EndUserProfileUpdate(BaseModel): """终端用户基本信息更新请求模型""" end_user_id: str = Field(description="终端用户ID") other_name: Optional[str] = Field(description="其他名称", default="") + aliases: Optional[List[str]] = Field(description="别名列表", default=None) position: Optional[str] = Field(description="职位", default=None) department: Optional[str] = Field(description="部门", default=None) contact: Optional[str] = Field(description="联系方式", default=None) From a7285e35ad88dc07677375106c36b0432ac5d9e3 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 24 Mar 2026 15:32:00 +0800 Subject: [PATCH 023/120] =?UTF-8?q?=E3=80=90add=E3=80=91Create=20user=20al?= =?UTF-8?q?ias=20table=20and=20functionality?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/user_memory_controllers.py | 230 ++++++++----- api/app/models/__init__.py | 2 + api/app/models/end_user_model.py | 13 +- api/app/models/user_alias_model.py | 24 ++ api/app/repositories/user_alias_repository.py | 90 ++++++ api/app/schemas/end_user_schema.py | 54 ++-- api/app/schemas/user_alias_schema.py | 33 ++ api/app/services/user_memory_service.py | 301 +++++++++++++++--- 8 files changed, 589 insertions(+), 158 deletions(-) create mode 100644 api/app/models/user_alias_model.py create mode 100644 api/app/repositories/user_alias_repository.py create mode 100644 api/app/schemas/user_alias_schema.py diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index 3ce1df6e..dbdc0a16 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -24,8 +24,9 @@ from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository from app.schemas.end_user_schema import ( - EndUserProfileResponse, - EndUserProfileUpdate, + UserAliasResponse, + UserAliasCreate, + UserAliasUpdate, ) from app.models.end_user_model import EndUser from app.dependencies import get_current_user @@ -336,103 +337,178 @@ async def get_community_graph_data_api( api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e)) +#=======================用户别名及信息接口======================= -@router.get("/read_end_user/profile", response_model=ApiResponse) -async def get_end_user_profile( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), +@router.get("/user_alias", response_model=ApiResponse) +async def get_user_alias( + user_alias_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: - workspace_id = current_user.current_workspace_id - workspace_repo = WorkspaceRepository(db) - workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + """ + 查询用户别名记录 + + 根据 user_alias_id 查询单条用户别名记录。 + """ + workspace_id = current_user.current_workspace_id - if workspace_models: - model_id = workspace_models.get("llm", None) - else: - model_id = None - # 检查用户是否已选择工作空间 if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间") + api_logger.warning(f"用户 {current_user.username} 尝试查询用户别名但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") api_logger.info( - f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, " + f"查询用户别名请求: user_alias_id={user_alias_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - try: - # 查询终端用户 - end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() - - if not end_user: - api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") - return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") - # 构建响应数据 - profile_data = EndUserProfileResponse( - id=end_user.id, - other_name=end_user.other_name, - position=end_user.position, - department=end_user.department, - contact=end_user.contact, - phone=end_user.phone, - hire_date=end_user.hire_date, - updatetime_profile=end_user.updatetime_profile - ) - - api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}") - return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功") - - except Exception as e: - api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}") - return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e)) - - -@router.post("/updated_end_user/profile", response_model=ApiResponse) -async def update_end_user_profile( - profile_update: EndUserProfileUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -) -> dict: - """ - 更新终端用户的基本信息 - - 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。 - 所有字段都是可选的,只更新提供的字段。 - """ - workspace_id = current_user.current_workspace_id - end_user_id = profile_update.end_user_id - - # 验证工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info( - f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, " - f"workspace={workspace_id}" - ) - - # 调用 Service 层处理业务逻辑 - result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update) + result = user_memory_service.get_user_alias(db, user_alias_id) if result["success"]: - api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}") - return success(data=result["data"], msg="更新成功") + api_logger.info(f"成功查询用户别名: user_alias_id={user_alias_id}") + return success(data=result["data"], msg="查询成功") else: error_msg = result["error"] - api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") + api_logger.error(f"查询用户别名失败: user_alias_id={user_alias_id}, error={error_msg}") + + if error_msg == "用户别名记录不存在": + return fail(BizCode.USER_NOT_FOUND, "用户别名记录不存在", error_msg) + elif error_msg == "无效的用户别名记录ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的用户别名记录ID格式", error_msg) + else: + return fail(BizCode.INTERNAL_ERROR, "查询用户别名失败", error_msg) - # 根据错误类型映射到合适的业务错误码 + +@router.post("/user_alias/create", response_model=ApiResponse) +async def create_user_alias( + alias_create: UserAliasCreate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 创建用户别名记录 + + 为指定用户创建一条新的别名记录,支持多个别名。 + """ + workspace_id = current_user.current_workspace_id + end_user_id = alias_create.end_user_id + + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试创建别名但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"创建用户别名请求: end_user_id={end_user_id}, aliases={alias_create.aliases}, " + f"user={current_user.username}, workspace={workspace_id}" + ) + + result = user_memory_service.create_user_alias( + db, end_user_id, alias_create.other_name, alias_create.aliases, alias_create.meta_data + ) + + if result["success"]: + api_logger.info(f"成功创建用户别名: end_user_id={end_user_id}") + return success(data=result["data"], msg="创建成功") + else: + error_msg = result["error"] + api_logger.error(f"用户别名创建失败: end_user_id={end_user_id}, error={error_msg}") + if error_msg == "终端用户不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) elif error_msg == "无效的用户ID格式": return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg) else: - # 只有未预期的错误才使用 INTERNAL_ERROR - return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) + return fail(BizCode.INTERNAL_ERROR, "用户别名创建失败", error_msg) +@router.post("/user_alias/updated", response_model=ApiResponse) +async def update_user_alias( + alias_update: UserAliasUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 更新用户别名记录 + + 根据 user_alias_id 更新用户别名记录,支持批量更新多个别名。 + + 示例请求体: + { + "user_alias_id": "2d4f57d4-639b-47aa-937a-d461bc2c2d53", + "other_name": "张三1", + "aliases": ["小张", "张工"], + "meta_data": {"position": "工程师", "department": "技术部"} + } + """ + workspace_id = current_user.current_workspace_id + user_alias_id = alias_update.user_alias_id + + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新用户别名但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"更新用户别名请求: user_alias_id={user_alias_id}, user={current_user.username}, " + f"workspace={workspace_id}" + ) + + # 获取更新数据(排除 user_alias_id) + update_data = alias_update.model_dump(exclude_unset=True, exclude={'user_alias_id'}) + + result = user_memory_service.update_user_alias(db, user_alias_id, update_data) + + if result["success"]: + api_logger.info(f"成功更新用户别名: user_alias_id={user_alias_id}") + return success(data=result["data"], msg="更新成功") + else: + error_msg = result["error"] + api_logger.error(f"用户别名更新失败: user_alias_id={user_alias_id}, error={error_msg}") + + if error_msg == "用户别名记录不存在": + return fail(BizCode.USER_NOT_FOUND, "用户别名记录不存在", error_msg) + elif error_msg == "无效的用户别名记录ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的用户别名记录ID格式", error_msg) + else: + return fail(BizCode.INTERNAL_ERROR, "用户别名更新失败", error_msg) + + +@router.delete("/user_alias", response_model=ApiResponse) +async def delete_user_alias( + user_alias_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 删除用户别名记录 + + 根据 user_alias_id 删除指定的用户别名记录。 + """ + workspace_id = current_user.current_workspace_id + + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试删除别名但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"删除用户别名请求: user_alias_id={user_alias_id}, user={current_user.username}, " + f"workspace={workspace_id}" + ) + + result = user_memory_service.delete_user_alias(db, user_alias_id) + + if result["success"]: + api_logger.info(f"成功删除用户别名: user_alias_id={user_alias_id}") + return success(data=result["data"], msg="删除成功") + else: + error_msg = result["error"] + api_logger.error(f"用户别名删除失败: user_alias_id={user_alias_id}, error={error_msg}") + + if error_msg == "用户别名记录不存在": + return fail(BizCode.USER_NOT_FOUND, "用户别名记录不存在", error_msg) + elif error_msg == "无效的用户别名记录ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的用户别名记录ID格式", error_msg) + else: + return fail(BizCode.INTERNAL_ERROR, "用户别名删除失败", error_msg) + @router.get("/memory_space/timeline_memories", response_model=ApiResponse) async def memory_space_timeline_of_shared_memories( id: str, label: str, diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index c6098a6d..22dd4851 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -16,6 +16,7 @@ from .agent_app_config_model import AgentConfig from .app_release_model import AppRelease from .memory_increment_model import MemoryIncrement from .end_user_model import EndUser +from .user_alias_model import UserAlias from .appshare_model import AppShare from .release_share_model import ReleaseShare from .conversation_model import Conversation, Message @@ -60,6 +61,7 @@ __all__ = [ "AppRelease", "MemoryIncrement", "EndUser", + "UserAlias", "AppShare", "ReleaseShare", "Conversation", diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index 60600fcf..a30e1dcb 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -30,14 +30,6 @@ class EndUser(Base): comment="关联的记忆配置ID" ) - # 用户基本信息字段 - position = Column(String, nullable=True, comment="职位") - department = Column(String, nullable=True, comment="部门") - contact = Column(String, nullable=True, comment="联系方式") - phone = Column(String, nullable=True, comment="电话") - hire_date = Column(DateTime, nullable=True, comment="入职日期") - updatetime_profile = Column(DateTime, nullable=True, comment="核心档案信息最后更新时间") - # 用户摘要四个维度 - User Summary Four Dimensions user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)") personality_traits = Column(Text, nullable=True, comment="性格特点") @@ -65,4 +57,7 @@ class EndUser(Base): ) # 与 WorkSpace 的反向关系 - workspace = relationship("Workspace", back_populates="end_users") \ No newline at end of file + workspace = relationship("Workspace", back_populates="end_users") + + # 与 UserAlias 的反向关系 + aliases = relationship("UserAlias", back_populates="end_user", cascade="all, delete-orphan") \ No newline at end of file diff --git a/api/app/models/user_alias_model.py b/api/app/models/user_alias_model.py new file mode 100644 index 00000000..ad862ead --- /dev/null +++ b/api/app/models/user_alias_model.py @@ -0,0 +1,24 @@ +import datetime +import uuid + +from sqlalchemy import Column, DateTime, ForeignKey, String, Text +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.orm import relationship + +from app.db import Base + + +class UserAlias(Base): + """用户别名表 - 存储用户的别名信息""" + __tablename__ = "user_aliases" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) + end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), nullable=False, index=True, comment="关联的终端用户ID") + other_name = Column(String, nullable=False, comment="关联的用户名称") + aliases = Column(JSONB, nullable=True, comment="用户别名列表(JSON数组)") + meta_data = Column(JSONB, nullable=True, comment="用户相关的扩展信息(JSON格式)") + created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + + # 与 EndUser 的关系 + end_user = relationship("EndUser", back_populates="aliases") diff --git a/api/app/repositories/user_alias_repository.py b/api/app/repositories/user_alias_repository.py new file mode 100644 index 00000000..96f8a778 --- /dev/null +++ b/api/app/repositories/user_alias_repository.py @@ -0,0 +1,90 @@ +""" +用户别名仓储层 +""" +import uuid +from typing import List, Optional +from sqlalchemy.orm import Session + +from app.models.user_alias_model import UserAlias +from app.core.logging_config import get_logger + +logger = get_logger(__name__) + + +class UserAliasRepository: + """用户别名仓储类""" + + def __init__(self, db: Session): + self.db = db + + def create(self, end_user_id: uuid.UUID, other_name: str, alias: str = None, meta_data: dict = None) -> UserAlias: + """创建用户别名""" + user_alias = UserAlias( + end_user_id=end_user_id, + other_name=other_name, + alias=alias, + meta_data=meta_data + ) + self.db.add(user_alias) + self.db.commit() + self.db.refresh(user_alias) + logger.info(f"创建用户别名: end_user_id={end_user_id}, alias={alias}") + return user_alias + + def get_by_id(self, alias_id: uuid.UUID) -> Optional[UserAlias]: + """根据ID获取别名""" + return self.db.query(UserAlias).filter(UserAlias.id == alias_id).first() + + def get_by_end_user_id(self, end_user_id: uuid.UUID) -> List[UserAlias]: + """获取用户的所有别名""" + return self.db.query(UserAlias).filter(UserAlias.end_user_id == end_user_id).all() + + def update(self, alias_id: uuid.UUID, alias: str = None, meta_data: dict = None) -> Optional[UserAlias]: + """更新别名""" + user_alias = self.get_by_id(alias_id) + if user_alias: + if alias is not None: + user_alias.alias = alias + if meta_data is not None: + user_alias.meta_data = meta_data + self.db.commit() + self.db.refresh(user_alias) + logger.info(f"更新用户别名: alias_id={alias_id}") + return user_alias + + def delete(self, alias_id: uuid.UUID) -> bool: + """删除别名""" + user_alias = self.get_by_id(alias_id) + if user_alias: + self.db.delete(user_alias) + self.db.commit() + logger.info(f"删除用户别名: alias_id={alias_id}") + return True + return False + + def delete_by_end_user_id(self, end_user_id: uuid.UUID) -> int: + """删除用户的所有别名""" + count = self.db.query(UserAlias).filter(UserAlias.end_user_id == end_user_id).delete() + self.db.commit() + logger.info(f"删除用户所有别名: end_user_id={end_user_id}, count={count}") + return count + + def batch_create(self, end_user_id: uuid.UUID, other_name: str, aliases: List[str]) -> List[UserAlias]: + """批量创建别名""" + user_aliases = [] + for alias in aliases: + if alias and alias.strip(): + user_alias = UserAlias( + end_user_id=end_user_id, + other_name=other_name, + alias=alias.strip() + ) + self.db.add(user_alias) + user_aliases.append(user_alias) + + self.db.commit() + for user_alias in user_aliases: + self.db.refresh(user_alias) + + logger.info(f"批量创建用户别名: end_user_id={end_user_id}, count={len(user_aliases)}") + return user_aliases diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index 09671b91..d541ba47 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -17,41 +17,35 @@ class EndUser(BaseModel): created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now) updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now) - # 用户基本信息字段 - position: Optional[str] = Field(description="职位", default=None) - department: Optional[str] = Field(description="部门", default=None) - contact: Optional[str] = Field(description="联系方式", default=None) - phone: Optional[str] = Field(description="电话", default=None) - hire_date: Optional[datetime.datetime] = Field(description="入职日期", default=None) - updatetime_profile: Optional[datetime.datetime] = Field(description="核心档案信息最后更新时间", default=None) - # 用户摘要和洞察更新时间 user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None) memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) -class EndUserProfileResponse(BaseModel): - """终端用户基本信息响应模型""" +class UserAliasResponse(BaseModel): + """用户别名响应模型""" model_config = ConfigDict(from_attributes=True) - id: uuid.UUID = Field(description="终端用户ID") - other_name: Optional[str] = Field(description="其他名称", default="") - position: Optional[str] = Field(description="职位", default=None) - department: Optional[str] = Field(description="部门", default=None) - contact: Optional[str] = Field(description="联系方式", default=None) - phone: Optional[str] = Field(description="电话", default=None) - hire_date: Optional[datetime.datetime] = Field(description="入职日期", default=None) - updatetime_profile: Optional[datetime.datetime] = Field(description="核心档案信息最后更新时间", default=None) - - - -class EndUserProfileUpdate(BaseModel): - """终端用户基本信息更新请求模型""" - end_user_id: str = Field(description="终端用户ID") - other_name: Optional[str] = Field(description="其他名称", default="") + user_alias_id: uuid.UUID = Field(description="用户别名记录ID") + end_user_id: uuid.UUID = Field(description="终端用户ID") + other_name: str = Field(description="用户名称") aliases: Optional[List[str]] = Field(description="别名列表", default=None) - position: Optional[str] = Field(description="职位", default=None) - department: Optional[str] = Field(description="部门", default=None) - contact: Optional[str] = Field(description="联系方式", default=None) - phone: Optional[str] = Field(description="电话", default=None) - hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None) \ No newline at end of file + meta_data: Optional[dict] = Field(description="扩展信息", default=None) + created_at: datetime.datetime = Field(description="创建时间") + updated_at: datetime.datetime = Field(description="更新时间") + + +class UserAliasCreate(BaseModel): + """创建用户别名请求模型""" + end_user_id: str = Field(description="终端用户ID") + other_name: str = Field(description="用户名称") + aliases: Optional[List[str]] = Field(description="别名列表", default=None) + meta_data: Optional[dict] = Field(description="扩展信息", default=None) + + +class UserAliasUpdate(BaseModel): + """更新用户别名请求模型""" + user_alias_id: str = Field(description="用户别名记录ID") + other_name: Optional[str] = Field(description="用户名称", default=None) + aliases: Optional[List[str]] = Field(description="别名列表", default=None) + meta_data: Optional[dict] = Field(description="扩展信息", default=None) \ No newline at end of file diff --git a/api/app/schemas/user_alias_schema.py b/api/app/schemas/user_alias_schema.py new file mode 100644 index 00000000..847c5c5d --- /dev/null +++ b/api/app/schemas/user_alias_schema.py @@ -0,0 +1,33 @@ +import uuid +import datetime +from typing import Optional, Dict, Any +from pydantic import BaseModel, Field +from pydantic import ConfigDict + + +class UserAliasBase(BaseModel): + """用户别名基础模型""" + other_name: str = Field(description="关联的用户名称") + alias: Optional[str] = Field(description="用户别名", default=None) + meta_data: Optional[Dict[str, Any]] = Field(description="用户相关的扩展信息", default=None) + + +class UserAliasCreate(UserAliasBase): + """创建用户别名请求模型""" + end_user_id: uuid.UUID = Field(description="关联的终端用户ID") + + +class UserAliasUpdate(BaseModel): + """更新用户别名请求模型""" + alias: Optional[str] = Field(description="用户别名", default=None) + meta_data: Optional[Dict[str, Any]] = Field(description="用户相关的扩展信息", default=None) + + +class UserAliasResponse(UserAliasBase): + """用户别名响应模型""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID = Field(description="别名ID") + end_user_id: uuid.UUID = Field(description="关联的终端用户ID") + created_at: datetime.datetime = Field(description="创建时间") + updated_at: datetime.datetime = Field(description="更新时间") diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 585fdd78..5aa589e8 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -361,29 +361,105 @@ class UserMemoryService: if hasattr(original_value, 'timestamp'): data[key] = UserMemoryService._datetime_to_timestamp(original_value) return data - - def update_end_user_profile( + # ======================== 用户别名及信息 ======================== + def get_user_alias( self, db: Session, - end_user_id: str, - profile_update: Any + user_alias_id: str ) -> Dict[str, Any]: """ - 更新终端用户的基本信息 + 查询单个用户别名记录 Args: db: 数据库会话 - end_user_id: 终端用户ID (UUID) - profile_update: 包含更新字段的 Pydantic 模型 + user_alias_id: 用户别名记录ID (UUID) Returns: { "success": bool, - "data": dict, # 更新后的用户档案数据 + "data": dict, "error": Optional[str] } """ try: + from app.models.user_alias_model import UserAlias + + # 转换为UUID并查询 + alias_uuid = uuid.UUID(user_alias_id) + user_alias_record = db.query(UserAlias).filter(UserAlias.id == alias_uuid).first() + + if not user_alias_record: + logger.warning(f"用户别名记录不存在: user_alias_id={user_alias_id}") + return { + "success": False, + "data": None, + "error": "用户别名记录不存在" + } + + # 构建响应数据 + from app.schemas.end_user_schema import UserAliasResponse + response_data = UserAliasResponse( + user_alias_id=user_alias_record.id, + end_user_id=user_alias_record.end_user_id, + other_name=user_alias_record.other_name, + aliases=user_alias_record.aliases, + meta_data=user_alias_record.meta_data, + created_at=user_alias_record.created_at, + updated_at=user_alias_record.updated_at + ) + + logger.info(f"成功查询用户别名记录: user_alias_id={user_alias_id}") + + return { + "success": True, + "data": response_data.model_dump(), + "error": None + } + + except ValueError: + logger.error(f"无效的 user_alias_id 格式: {user_alias_id}") + return { + "success": False, + "data": None, + "error": "无效的用户别名记录ID格式" + } + except Exception as e: + logger.error(f"查询用户别名记录失败: user_alias_id={user_alias_id}, error={str(e)}") + return { + "success": False, + "data": None, + "error": str(e) + } + + def create_user_alias( + self, + db: Session, + end_user_id: str, + other_name: str, + aliases: List[str] = None, + meta_data: dict = None + ) -> Dict[str, Any]: + """ + 创建用户别名记录 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + other_name: 用户名称 + aliases: 别名列表 + meta_data: 扩展信息 + + Returns: + { + "success": bool, + "data": dict, + "error": Optional[str] + } + """ + try: + from app.models.user_alias_model import UserAlias + from app.repositories.end_user_repository import EndUserRepository + # 转换为UUID并查询用户 user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) @@ -397,47 +473,34 @@ class UserMemoryService: "error": "终端用户不存在" } - # 获取更新数据(排除 end_user_id 字段) - update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) - - # 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime - if 'hire_date' in update_data: - hire_date_timestamp = update_data['hire_date'] - if hire_date_timestamp is not None: - from app.core.api_key_utils import timestamp_to_datetime - update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) - # 如果是 None,保持 None(允许清空) - - # 更新字段 - for field, value in update_data.items(): - setattr(end_user, field, value) - - # 更新时间戳 - end_user.updated_at = datetime.now() - end_user.updatetime_profile = datetime.now() - - # 提交更改 + # 创建新的别名记录 + new_alias = UserAlias( + end_user_id=user_uuid, + other_name=other_name, + aliases=aliases, + meta_data=meta_data + ) + db.add(new_alias) db.commit() - db.refresh(end_user) + db.refresh(new_alias) # 构建响应数据 - from app.schemas.end_user_schema import EndUserProfileResponse - profile_data = EndUserProfileResponse( - id=end_user.id, - other_name=end_user.other_name, - position=end_user.position, - department=end_user.department, - contact=end_user.contact, - phone=end_user.phone, - hire_date=end_user.hire_date, - updatetime_profile=end_user.updatetime_profile + from app.schemas.end_user_schema import UserAliasResponse + response_data = UserAliasResponse( + user_alias_id=new_alias.id, + end_user_id=new_alias.end_user_id, + other_name=new_alias.other_name, + aliases=new_alias.aliases, + meta_data=new_alias.meta_data, + created_at=new_alias.created_at, + updated_at=new_alias.updated_at ) - logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}") + logger.info(f"成功创建用户别名记录: end_user_id={end_user_id}") return { "success": True, - "data": self.convert_profile_to_dict_with_timestamp(profile_data), + "data": response_data.model_dump(), "error": None } @@ -450,7 +513,161 @@ class UserMemoryService: } except Exception as e: db.rollback() - logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") + logger.error(f"创建用户别名记录失败: end_user_id={end_user_id}, error={str(e)}") + return { + "success": False, + "data": None, + "error": str(e) + } + + def update_user_alias( + self, + db: Session, + user_alias_id: str, + update_data: Dict[str, Any] + ) -> Dict[str, Any]: + """ + 更新用户别名记录 + + Args: + db: 数据库会话 + user_alias_id: 用户别名记录ID (UUID) + update_data: 更新数据字典 + + Returns: + { + "success": bool, + "data": dict, + "error": Optional[str] + } + """ + try: + from app.models.user_alias_model import UserAlias + + # 转换为UUID并查询 + alias_uuid = uuid.UUID(user_alias_id) + user_alias_record = db.query(UserAlias).filter(UserAlias.id == alias_uuid).first() + + if not user_alias_record: + logger.warning(f"用户别名记录不存在: user_alias_id={user_alias_id}") + return { + "success": False, + "data": None, + "error": "用户别名记录不存在" + } + + # 更新字段 + for field, value in update_data.items(): + if hasattr(user_alias_record, field) and field != 'user_alias_id': + setattr(user_alias_record, field, value) + + # 更新时间戳 + user_alias_record.updated_at = datetime.now() + + # 提交更改 + db.commit() + db.refresh(user_alias_record) + + # 构建响应数据 + from app.schemas.end_user_schema import UserAliasResponse + response_data = UserAliasResponse( + user_alias_id=user_alias_record.id, + end_user_id=user_alias_record.end_user_id, + other_name=user_alias_record.other_name, + aliases=user_alias_record.aliases, + meta_data=user_alias_record.meta_data, + created_at=user_alias_record.created_at, + updated_at=user_alias_record.updated_at + ) + + logger.info(f"成功更新用户别名记录: user_alias_id={user_alias_id}, updated_fields={list(update_data.keys())}") + + return { + "success": True, + "data": response_data.model_dump(), + "error": None + } + + except ValueError: + logger.error(f"无效的 user_alias_id 格式: {user_alias_id}") + return { + "success": False, + "data": None, + "error": "无效的用户别名记录ID格式" + } + except Exception as e: + db.rollback() + logger.error(f"更新用户别名记录失败: user_alias_id={user_alias_id}, error={str(e)}") + return { + "success": False, + "data": None, + "error": str(e) + } + + def delete_user_alias( + self, + db: Session, + user_alias_id: str + ) -> Dict[str, Any]: + """ + 删除用户别名记录 + + Args: + db: 数据库会话 + user_alias_id: 用户别名记录ID (UUID) + + Returns: + { + "success": bool, + "data": dict, + "error": Optional[str] + } + """ + try: + from app.models.user_alias_model import UserAlias + + # 转换为UUID并查询 + alias_uuid = uuid.UUID(user_alias_id) + user_alias_record = db.query(UserAlias).filter(UserAlias.id == alias_uuid).first() + + if not user_alias_record: + logger.warning(f"用户别名记录不存在: user_alias_id={user_alias_id}") + return { + "success": False, + "data": None, + "error": "用户别名记录不存在" + } + + # 删除记录 + db.delete(user_alias_record) + db.commit() + + logger.info(f"成功删除用户别名记录: user_alias_id={user_alias_id}") + + return { + "success": True, + "data": {"user_alias_id": user_alias_id}, + "error": None + } + + except ValueError: + logger.error(f"无效的 user_alias_id 格式: {user_alias_id}") + return { + "success": False, + "data": None, + "error": "无效的用户别名记录ID格式" + } + except Exception as e: + db.rollback() + logger.error(f"删除用户别名记录失败: user_alias_id={user_alias_id}, error={str(e)}") + return { + "success": False, + "data": None, + "error": str(e) + } + except Exception as e: + db.rollback() + logger.error(f"用户别名记录更新失败: user_alias_id={user_alias_id}, error={str(e)}") return { "success": False, "data": None, From e8d575fd0b8508e30566c9b28d9b5ce797c0f628 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 24 Mar 2026 16:10:23 +0800 Subject: [PATCH 024/120] [add] Separate the definitions of end_user and user_alias --- .../controllers/user_memory_controllers.py | 2 +- api/app/schemas/end_user_schema.py | 31 +------------------ api/app/schemas/user_alias_schema.py | 12 ++++--- api/app/services/user_memory_service.py | 13 +++++--- 4 files changed, 17 insertions(+), 41 deletions(-) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index dbdc0a16..d6b910a3 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -23,7 +23,7 @@ from app.services.memory_entity_relationship_service import MemoryEntityService, from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository -from app.schemas.end_user_schema import ( +from app.schemas.user_alias_schema import ( UserAliasResponse, UserAliasCreate, UserAliasUpdate, diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index d541ba47..c2498203 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -19,33 +19,4 @@ class EndUser(BaseModel): # 用户摘要和洞察更新时间 user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None) - memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) - - -class UserAliasResponse(BaseModel): - """用户别名响应模型""" - model_config = ConfigDict(from_attributes=True) - - user_alias_id: uuid.UUID = Field(description="用户别名记录ID") - end_user_id: uuid.UUID = Field(description="终端用户ID") - other_name: str = Field(description="用户名称") - aliases: Optional[List[str]] = Field(description="别名列表", default=None) - meta_data: Optional[dict] = Field(description="扩展信息", default=None) - created_at: datetime.datetime = Field(description="创建时间") - updated_at: datetime.datetime = Field(description="更新时间") - - -class UserAliasCreate(BaseModel): - """创建用户别名请求模型""" - end_user_id: str = Field(description="终端用户ID") - other_name: str = Field(description="用户名称") - aliases: Optional[List[str]] = Field(description="别名列表", default=None) - meta_data: Optional[dict] = Field(description="扩展信息", default=None) - - -class UserAliasUpdate(BaseModel): - """更新用户别名请求模型""" - user_alias_id: str = Field(description="用户别名记录ID") - other_name: Optional[str] = Field(description="用户名称", default=None) - aliases: Optional[List[str]] = Field(description="别名列表", default=None) - meta_data: Optional[dict] = Field(description="扩展信息", default=None) \ No newline at end of file + memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) \ No newline at end of file diff --git a/api/app/schemas/user_alias_schema.py b/api/app/schemas/user_alias_schema.py index 847c5c5d..a8bf7700 100644 --- a/api/app/schemas/user_alias_schema.py +++ b/api/app/schemas/user_alias_schema.py @@ -1,6 +1,6 @@ import uuid import datetime -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from pydantic import BaseModel, Field from pydantic import ConfigDict @@ -8,18 +8,20 @@ from pydantic import ConfigDict class UserAliasBase(BaseModel): """用户别名基础模型""" other_name: str = Field(description="关联的用户名称") - alias: Optional[str] = Field(description="用户别名", default=None) + aliases: Optional[List[str]] = Field(description="用户别名列表", default=None) meta_data: Optional[Dict[str, Any]] = Field(description="用户相关的扩展信息", default=None) class UserAliasCreate(UserAliasBase): """创建用户别名请求模型""" - end_user_id: uuid.UUID = Field(description="关联的终端用户ID") + end_user_id: str = Field(description="关联的终端用户ID") class UserAliasUpdate(BaseModel): """更新用户别名请求模型""" - alias: Optional[str] = Field(description="用户别名", default=None) + user_alias_id: str = Field(description="用户别名记录ID") + other_name: Optional[str] = Field(description="用户名称", default=None) + aliases: Optional[List[str]] = Field(description="用户别名列表", default=None) meta_data: Optional[Dict[str, Any]] = Field(description="用户相关的扩展信息", default=None) @@ -27,7 +29,7 @@ class UserAliasResponse(UserAliasBase): """用户别名响应模型""" model_config = ConfigDict(from_attributes=True) - id: uuid.UUID = Field(description="别名ID") + user_alias_id: uuid.UUID = Field(description="用户别名记录ID") end_user_id: uuid.UUID = Field(description="关联的终端用户ID") created_at: datetime.datetime = Field(description="创建时间") updated_at: datetime.datetime = Field(description="更新时间") diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 5aa589e8..0a01b6dc 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -397,7 +397,7 @@ class UserMemoryService: } # 构建响应数据 - from app.schemas.end_user_schema import UserAliasResponse + from app.schemas.user_alias_schema import UserAliasResponse response_data = UserAliasResponse( user_alias_id=user_alias_record.id, end_user_id=user_alias_record.end_user_id, @@ -485,7 +485,7 @@ class UserMemoryService: db.refresh(new_alias) # 构建响应数据 - from app.schemas.end_user_schema import UserAliasResponse + from app.schemas.user_alias_schema import UserAliasResponse response_data = UserAliasResponse( user_alias_id=new_alias.id, end_user_id=new_alias.end_user_id, @@ -556,9 +556,12 @@ class UserMemoryService: "error": "用户别名记录不存在" } - # 更新字段 + # 定义允许更新的字段白名单 + allowed_fields = {'other_name', 'aliases', 'meta_data'} + + # 更新字段(仅允许白名单中的字段) for field, value in update_data.items(): - if hasattr(user_alias_record, field) and field != 'user_alias_id': + if field in allowed_fields: setattr(user_alias_record, field, value) # 更新时间戳 @@ -569,7 +572,7 @@ class UserMemoryService: db.refresh(user_alias_record) # 构建响应数据 - from app.schemas.end_user_schema import UserAliasResponse + from app.schemas.user_alias_schema import UserAliasResponse response_data = UserAliasResponse( user_alias_id=user_alias_record.id, end_user_id=user_alias_record.end_user_id, From db14d40fb39f71dc68bf3fd8e7ea2a908e94b677 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 24 Mar 2026 17:30:47 +0800 Subject: [PATCH 025/120] =?UTF-8?q?[changes]=20user=5Falias=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E4=B8=BAend=5Fuser=5Finfo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/user_memory_controllers.py | 134 +++++++------- api/app/models/__init__.py | 2 +- ..._alias_model.py => end_user_info_model.py} | 8 +- api/app/models/end_user_model.py | 4 +- .../repositories/end_user_info_repository.py | 90 ++++++++++ api/app/repositories/user_alias_repository.py | 90 ---------- ...lias_schema.py => end_user_info_schema.py} | 20 +-- api/app/services/user_memory_service.py | 166 +++++++++--------- 8 files changed, 257 insertions(+), 257 deletions(-) rename api/app/models/{user_alias_model.py => end_user_info_model.py} (83%) create mode 100644 api/app/repositories/end_user_info_repository.py delete mode 100644 api/app/repositories/user_alias_repository.py rename api/app/schemas/{user_alias_schema.py => end_user_info_schema.py} (67%) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index d6b910a3..b4c33032 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -23,10 +23,10 @@ from app.services.memory_entity_relationship_service import MemoryEntityService, from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository -from app.schemas.user_alias_schema import ( - UserAliasResponse, - UserAliasCreate, - UserAliasUpdate, +from app.schemas.end_user_info_schema import ( + EndUserInfoResponse, + EndUserInfoCreate, + EndUserInfoUpdate, ) from app.models.end_user_model import EndUser from app.dependencies import get_current_user @@ -337,177 +337,177 @@ async def get_community_graph_data_api( api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e)) -#=======================用户别名及信息接口======================= +#=======================终端用户信息接口======================= -@router.get("/user_alias", response_model=ApiResponse) -async def get_user_alias( - user_alias_id: str, +@router.get("/end_user_info", response_model=ApiResponse) +async def get_end_user_info( + end_user_info_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: """ - 查询用户别名记录 + 查询终端用户信息记录 - 根据 user_alias_id 查询单条用户别名记录。 + 根据 end_user_info_id 查询单条终端用户信息记录。 """ workspace_id = current_user.current_workspace_id if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试查询用户别名但未选择工作空间") + api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") api_logger.info( - f"查询用户别名请求: user_alias_id={user_alias_id}, user={current_user.username}, " + f"查询终端用户信息请求: end_user_info_id={end_user_info_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - result = user_memory_service.get_user_alias(db, user_alias_id) + result = user_memory_service.get_end_user_info(db, end_user_info_id) if result["success"]: - api_logger.info(f"成功查询用户别名: user_alias_id={user_alias_id}") + api_logger.info(f"成功查询终端用户信息: end_user_info_id={end_user_info_id}") return success(data=result["data"], msg="查询成功") else: error_msg = result["error"] - api_logger.error(f"查询用户别名失败: user_alias_id={user_alias_id}, error={error_msg}") + api_logger.error(f"查询终端用户信息失败: end_user_info_id={end_user_info_id}, error={error_msg}") - if error_msg == "用户别名记录不存在": - return fail(BizCode.USER_NOT_FOUND, "用户别名记录不存在", error_msg) - elif error_msg == "无效的用户别名记录ID格式": - return fail(BizCode.INVALID_USER_ID, "无效的用户别名记录ID格式", error_msg) + if error_msg == "终端用户信息记录不存在": + return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg) + elif error_msg == "无效的终端用户信息记录ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的终端用户信息记录ID格式", error_msg) else: - return fail(BizCode.INTERNAL_ERROR, "查询用户别名失败", error_msg) + return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg) -@router.post("/user_alias/create", response_model=ApiResponse) -async def create_user_alias( - alias_create: UserAliasCreate, +@router.post("/end_user_info/create", response_model=ApiResponse) +async def create_end_user_info( + info_create: EndUserInfoCreate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: """ - 创建用户别名记录 + 创建终端用户信息记录 - 为指定用户创建一条新的别名记录,支持多个别名。 + 为指定用户创建一条新的信息记录,支持多个别名。 """ workspace_id = current_user.current_workspace_id - end_user_id = alias_create.end_user_id + end_user_id = info_create.end_user_id if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试创建别名但未选择工作空间") + api_logger.warning(f"用户 {current_user.username} 尝试创建终端用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") api_logger.info( - f"创建用户别名请求: end_user_id={end_user_id}, aliases={alias_create.aliases}, " + f"创建终端用户信息请求: end_user_id={end_user_id}, aliases={info_create.aliases}, " f"user={current_user.username}, workspace={workspace_id}" ) - result = user_memory_service.create_user_alias( - db, end_user_id, alias_create.other_name, alias_create.aliases, alias_create.meta_data + result = user_memory_service.create_end_user_info( + db, end_user_id, info_create.other_name, info_create.aliases, info_create.meta_data ) if result["success"]: - api_logger.info(f"成功创建用户别名: end_user_id={end_user_id}") + api_logger.info(f"成功创建终端用户信息: end_user_id={end_user_id}") return success(data=result["data"], msg="创建成功") else: error_msg = result["error"] - api_logger.error(f"用户别名创建失败: end_user_id={end_user_id}, error={error_msg}") + api_logger.error(f"终端用户信息创建失败: end_user_id={end_user_id}, error={error_msg}") if error_msg == "终端用户不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) elif error_msg == "无效的用户ID格式": return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg) else: - return fail(BizCode.INTERNAL_ERROR, "用户别名创建失败", error_msg) + return fail(BizCode.INTERNAL_ERROR, "终端用户信息创建失败", error_msg) -@router.post("/user_alias/updated", response_model=ApiResponse) -async def update_user_alias( - alias_update: UserAliasUpdate, +@router.post("/end_user_info/updated", response_model=ApiResponse) +async def update_end_user_info( + info_update: EndUserInfoUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: """ - 更新用户别名记录 + 更新终端用户信息记录 - 根据 user_alias_id 更新用户别名记录,支持批量更新多个别名。 + 根据 end_user_info_id 更新终端用户信息记录,支持批量更新多个别名。 示例请求体: { - "user_alias_id": "2d4f57d4-639b-47aa-937a-d461bc2c2d53", + "end_user_info_id": "2d4f57d4-639b-47aa-937a-d461bc2c2d53", "other_name": "张三1", "aliases": ["小张", "张工"], "meta_data": {"position": "工程师", "department": "技术部"} } """ workspace_id = current_user.current_workspace_id - user_alias_id = alias_update.user_alias_id + end_user_info_id = info_update.end_user_info_id if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试更新用户别名但未选择工作空间") + api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") api_logger.info( - f"更新用户别名请求: user_alias_id={user_alias_id}, user={current_user.username}, " + f"更新终端用户信息请求: end_user_info_id={end_user_info_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - # 获取更新数据(排除 user_alias_id) - update_data = alias_update.model_dump(exclude_unset=True, exclude={'user_alias_id'}) + # 获取更新数据(排除 end_user_info_id) + update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_info_id'}) - result = user_memory_service.update_user_alias(db, user_alias_id, update_data) + result = user_memory_service.update_end_user_info(db, end_user_info_id, update_data) if result["success"]: - api_logger.info(f"成功更新用户别名: user_alias_id={user_alias_id}") + api_logger.info(f"成功更新终端用户信息: end_user_info_id={end_user_info_id}") return success(data=result["data"], msg="更新成功") else: error_msg = result["error"] - api_logger.error(f"用户别名更新失败: user_alias_id={user_alias_id}, error={error_msg}") + api_logger.error(f"终端用户信息更新失败: end_user_info_id={end_user_info_id}, error={error_msg}") - if error_msg == "用户别名记录不存在": - return fail(BizCode.USER_NOT_FOUND, "用户别名记录不存在", error_msg) - elif error_msg == "无效的用户别名记录ID格式": - return fail(BizCode.INVALID_USER_ID, "无效的用户别名记录ID格式", error_msg) + if error_msg == "终端用户信息记录不存在": + return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg) + elif error_msg == "无效的终端用户信息记录ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的终端用户信息记录ID格式", error_msg) else: - return fail(BizCode.INTERNAL_ERROR, "用户别名更新失败", error_msg) + return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg) -@router.delete("/user_alias", response_model=ApiResponse) -async def delete_user_alias( - user_alias_id: str, +@router.delete("/end_user_info", response_model=ApiResponse) +async def delete_end_user_info( + end_user_info_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: """ - 删除用户别名记录 + 删除终端用户信息记录 - 根据 user_alias_id 删除指定的用户别名记录。 + 根据 end_user_info_id 删除指定的终端用户信息记录。 """ workspace_id = current_user.current_workspace_id if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试删除别名但未选择工作空间") + api_logger.warning(f"用户 {current_user.username} 尝试删除终端用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") api_logger.info( - f"删除用户别名请求: user_alias_id={user_alias_id}, user={current_user.username}, " + f"删除终端用户信息请求: end_user_info_id={end_user_info_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - result = user_memory_service.delete_user_alias(db, user_alias_id) + result = user_memory_service.delete_end_user_info(db, end_user_info_id) if result["success"]: - api_logger.info(f"成功删除用户别名: user_alias_id={user_alias_id}") + api_logger.info(f"成功删除终端用户信息: end_user_info_id={end_user_info_id}") return success(data=result["data"], msg="删除成功") else: error_msg = result["error"] - api_logger.error(f"用户别名删除失败: user_alias_id={user_alias_id}, error={error_msg}") + api_logger.error(f"终端用户信息删除失败: end_user_info_id={end_user_info_id}, error={error_msg}") - if error_msg == "用户别名记录不存在": - return fail(BizCode.USER_NOT_FOUND, "用户别名记录不存在", error_msg) - elif error_msg == "无效的用户别名记录ID格式": - return fail(BizCode.INVALID_USER_ID, "无效的用户别名记录ID格式", error_msg) + if error_msg == "终端用户信息记录不存在": + return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg) + elif error_msg == "无效的终端用户信息记录ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的终端用户信息记录ID格式", error_msg) else: - return fail(BizCode.INTERNAL_ERROR, "用户别名删除失败", error_msg) + return fail(BizCode.INTERNAL_ERROR, "终端用户信息删除失败", error_msg) @router.get("/memory_space/timeline_memories", response_model=ApiResponse) async def memory_space_timeline_of_shared_memories( diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 22dd4851..7dd26d34 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -16,7 +16,7 @@ from .agent_app_config_model import AgentConfig from .app_release_model import AppRelease from .memory_increment_model import MemoryIncrement from .end_user_model import EndUser -from .user_alias_model import UserAlias +from .end_user_info_model import EndUserInfo from .appshare_model import AppShare from .release_share_model import ReleaseShare from .conversation_model import Conversation, Message diff --git a/api/app/models/user_alias_model.py b/api/app/models/end_user_info_model.py similarity index 83% rename from api/app/models/user_alias_model.py rename to api/app/models/end_user_info_model.py index ad862ead..ed747002 100644 --- a/api/app/models/user_alias_model.py +++ b/api/app/models/end_user_info_model.py @@ -8,9 +8,9 @@ from sqlalchemy.orm import relationship from app.db import Base -class UserAlias(Base): - """用户别名表 - 存储用户的别名信息""" - __tablename__ = "user_aliases" +class EndUserInfo(Base): + """终端用户信息表 - 存储用户的别名和扩展信息""" + __tablename__ = "end_user_info" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), nullable=False, index=True, comment="关联的终端用户ID") @@ -21,4 +21,4 @@ class UserAlias(Base): updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") # 与 EndUser 的关系 - end_user = relationship("EndUser", back_populates="aliases") + end_user = relationship("EndUser", back_populates="info") diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index a30e1dcb..a821680f 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -59,5 +59,5 @@ class EndUser(Base): # 与 WorkSpace 的反向关系 workspace = relationship("Workspace", back_populates="end_users") - # 与 UserAlias 的反向关系 - aliases = relationship("UserAlias", back_populates="end_user", cascade="all, delete-orphan") \ No newline at end of file + # 与 EndUserInfo 的反向关系 + info = relationship("EndUserInfo", back_populates="end_user", cascade="all, delete-orphan") \ No newline at end of file diff --git a/api/app/repositories/end_user_info_repository.py b/api/app/repositories/end_user_info_repository.py new file mode 100644 index 00000000..ee05d12d --- /dev/null +++ b/api/app/repositories/end_user_info_repository.py @@ -0,0 +1,90 @@ +""" +终端用户信息仓储层 +""" +import uuid +from typing import List, Optional +from sqlalchemy.orm import Session + +from app.models.end_user_info_model import EndUserInfo +from app.core.logging_config import get_logger + +logger = get_logger(__name__) + + +class EndUserInfoRepository: + """终端用户信息仓储类""" + + def __init__(self, db: Session): + self.db = db + + def create(self, end_user_id: uuid.UUID, other_name: str, alias: str = None, meta_data: dict = None) -> EndUserInfo: + """创建终端用户信息""" + end_user_info = EndUserInfo( + end_user_id=end_user_id, + other_name=other_name, + alias=alias, + meta_data=meta_data + ) + self.db.add(end_user_info) + self.db.commit() + self.db.refresh(end_user_info) + logger.info(f"创建终端用户信息: end_user_id={end_user_id}, alias={alias}") + return end_user_info + + def get_by_id(self, info_id: uuid.UUID) -> Optional[EndUserInfo]: + """根据ID获取用户信息""" + return self.db.query(EndUserInfo).filter(EndUserInfo.id == info_id).first() + + def get_by_end_user_id(self, end_user_id: uuid.UUID) -> List[EndUserInfo]: + """获取用户的所有信息记录""" + return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).all() + + def update(self, info_id: uuid.UUID, alias: str = None, meta_data: dict = None) -> Optional[EndUserInfo]: + """更新用户信息""" + end_user_info = self.get_by_id(info_id) + if end_user_info: + if alias is not None: + end_user_info.alias = alias + if meta_data is not None: + end_user_info.meta_data = meta_data + self.db.commit() + self.db.refresh(end_user_info) + logger.info(f"更新终端用户信息: info_id={info_id}") + return end_user_info + + def delete(self, info_id: uuid.UUID) -> bool: + """删除用户信息""" + end_user_info = self.get_by_id(info_id) + if end_user_info: + self.db.delete(end_user_info) + self.db.commit() + logger.info(f"删除终端用户信息: info_id={info_id}") + return True + return False + + def delete_by_end_user_id(self, end_user_id: uuid.UUID) -> int: + """删除用户的所有信息记录""" + count = self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).delete() + self.db.commit() + logger.info(f"删除用户所有信息记录: end_user_id={end_user_id}, count={count}") + return count + + def batch_create(self, end_user_id: uuid.UUID, other_name: str, aliases: List[str]) -> List[EndUserInfo]: + """批量创建用户信息""" + end_user_infos = [] + for alias in aliases: + if alias and alias.strip(): + end_user_info = EndUserInfo( + end_user_id=end_user_id, + other_name=other_name, + alias=alias.strip() + ) + self.db.add(end_user_info) + end_user_infos.append(end_user_info) + + self.db.commit() + for end_user_info in end_user_infos: + self.db.refresh(end_user_info) + + logger.info(f"批量创建终端用户信息: end_user_id={end_user_id}, count={len(end_user_infos)}") + return end_user_infos diff --git a/api/app/repositories/user_alias_repository.py b/api/app/repositories/user_alias_repository.py deleted file mode 100644 index 96f8a778..00000000 --- a/api/app/repositories/user_alias_repository.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -用户别名仓储层 -""" -import uuid -from typing import List, Optional -from sqlalchemy.orm import Session - -from app.models.user_alias_model import UserAlias -from app.core.logging_config import get_logger - -logger = get_logger(__name__) - - -class UserAliasRepository: - """用户别名仓储类""" - - def __init__(self, db: Session): - self.db = db - - def create(self, end_user_id: uuid.UUID, other_name: str, alias: str = None, meta_data: dict = None) -> UserAlias: - """创建用户别名""" - user_alias = UserAlias( - end_user_id=end_user_id, - other_name=other_name, - alias=alias, - meta_data=meta_data - ) - self.db.add(user_alias) - self.db.commit() - self.db.refresh(user_alias) - logger.info(f"创建用户别名: end_user_id={end_user_id}, alias={alias}") - return user_alias - - def get_by_id(self, alias_id: uuid.UUID) -> Optional[UserAlias]: - """根据ID获取别名""" - return self.db.query(UserAlias).filter(UserAlias.id == alias_id).first() - - def get_by_end_user_id(self, end_user_id: uuid.UUID) -> List[UserAlias]: - """获取用户的所有别名""" - return self.db.query(UserAlias).filter(UserAlias.end_user_id == end_user_id).all() - - def update(self, alias_id: uuid.UUID, alias: str = None, meta_data: dict = None) -> Optional[UserAlias]: - """更新别名""" - user_alias = self.get_by_id(alias_id) - if user_alias: - if alias is not None: - user_alias.alias = alias - if meta_data is not None: - user_alias.meta_data = meta_data - self.db.commit() - self.db.refresh(user_alias) - logger.info(f"更新用户别名: alias_id={alias_id}") - return user_alias - - def delete(self, alias_id: uuid.UUID) -> bool: - """删除别名""" - user_alias = self.get_by_id(alias_id) - if user_alias: - self.db.delete(user_alias) - self.db.commit() - logger.info(f"删除用户别名: alias_id={alias_id}") - return True - return False - - def delete_by_end_user_id(self, end_user_id: uuid.UUID) -> int: - """删除用户的所有别名""" - count = self.db.query(UserAlias).filter(UserAlias.end_user_id == end_user_id).delete() - self.db.commit() - logger.info(f"删除用户所有别名: end_user_id={end_user_id}, count={count}") - return count - - def batch_create(self, end_user_id: uuid.UUID, other_name: str, aliases: List[str]) -> List[UserAlias]: - """批量创建别名""" - user_aliases = [] - for alias in aliases: - if alias and alias.strip(): - user_alias = UserAlias( - end_user_id=end_user_id, - other_name=other_name, - alias=alias.strip() - ) - self.db.add(user_alias) - user_aliases.append(user_alias) - - self.db.commit() - for user_alias in user_aliases: - self.db.refresh(user_alias) - - logger.info(f"批量创建用户别名: end_user_id={end_user_id}, count={len(user_aliases)}") - return user_aliases diff --git a/api/app/schemas/user_alias_schema.py b/api/app/schemas/end_user_info_schema.py similarity index 67% rename from api/app/schemas/user_alias_schema.py rename to api/app/schemas/end_user_info_schema.py index a8bf7700..f508190e 100644 --- a/api/app/schemas/user_alias_schema.py +++ b/api/app/schemas/end_user_info_schema.py @@ -5,31 +5,31 @@ from pydantic import BaseModel, Field from pydantic import ConfigDict -class UserAliasBase(BaseModel): - """用户别名基础模型""" +class EndUserInfoBase(BaseModel): + """终端用户信息基础模型""" other_name: str = Field(description="关联的用户名称") aliases: Optional[List[str]] = Field(description="用户别名列表", default=None) meta_data: Optional[Dict[str, Any]] = Field(description="用户相关的扩展信息", default=None) -class UserAliasCreate(UserAliasBase): - """创建用户别名请求模型""" +class EndUserInfoCreate(EndUserInfoBase): + """创建终端用户信息请求模型""" end_user_id: str = Field(description="关联的终端用户ID") -class UserAliasUpdate(BaseModel): - """更新用户别名请求模型""" - user_alias_id: str = Field(description="用户别名记录ID") +class EndUserInfoUpdate(BaseModel): + """更新终端用户信息请求模型""" + end_user_info_id: str = Field(description="终端用户信息记录ID") other_name: Optional[str] = Field(description="用户名称", default=None) aliases: Optional[List[str]] = Field(description="用户别名列表", default=None) meta_data: Optional[Dict[str, Any]] = Field(description="用户相关的扩展信息", default=None) -class UserAliasResponse(UserAliasBase): - """用户别名响应模型""" +class EndUserInfoResponse(EndUserInfoBase): + """终端用户信息响应模型""" model_config = ConfigDict(from_attributes=True) - user_alias_id: uuid.UUID = Field(description="用户别名记录ID") + end_user_info_id: uuid.UUID = Field(description="终端用户信息记录ID") end_user_id: uuid.UUID = Field(description="关联的终端用户ID") created_at: datetime.datetime = Field(description="创建时间") updated_at: datetime.datetime = Field(description="更新时间") diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 0a01b6dc..ba35c22f 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -362,17 +362,17 @@ class UserMemoryService: data[key] = UserMemoryService._datetime_to_timestamp(original_value) return data # ======================== 用户别名及信息 ======================== - def get_user_alias( + def get_end_user_info( self, db: Session, - user_alias_id: str + end_user_info_id: str ) -> Dict[str, Any]: """ - 查询单个用户别名记录 + 查询单个终端用户信息记录 Args: db: 数据库会话 - user_alias_id: 用户别名记录ID (UUID) + end_user_info_id: 终端用户信息记录ID (UUID) Returns: { @@ -382,33 +382,33 @@ class UserMemoryService: } """ try: - from app.models.user_alias_model import UserAlias + from app.models.end_user_info_model import EndUserInfo # 转换为UUID并查询 - alias_uuid = uuid.UUID(user_alias_id) - user_alias_record = db.query(UserAlias).filter(UserAlias.id == alias_uuid).first() + info_uuid = uuid.UUID(end_user_info_id) + end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.id == info_uuid).first() - if not user_alias_record: - logger.warning(f"用户别名记录不存在: user_alias_id={user_alias_id}") + if not end_user_info_record: + logger.warning(f"终端用户信息记录不存在: end_user_info_id={end_user_info_id}") return { "success": False, "data": None, - "error": "用户别名记录不存在" + "error": "终端用户信息记录不存在" } # 构建响应数据 - from app.schemas.user_alias_schema import UserAliasResponse - response_data = UserAliasResponse( - user_alias_id=user_alias_record.id, - end_user_id=user_alias_record.end_user_id, - other_name=user_alias_record.other_name, - aliases=user_alias_record.aliases, - meta_data=user_alias_record.meta_data, - created_at=user_alias_record.created_at, - updated_at=user_alias_record.updated_at + from app.schemas.end_user_info_schema import EndUserInfoResponse + response_data = EndUserInfoResponse( + end_user_info_id=end_user_info_record.id, + end_user_id=end_user_info_record.end_user_id, + other_name=end_user_info_record.other_name, + aliases=end_user_info_record.aliases, + meta_data=end_user_info_record.meta_data, + created_at=end_user_info_record.created_at, + updated_at=end_user_info_record.updated_at ) - logger.info(f"成功查询用户别名记录: user_alias_id={user_alias_id}") + logger.info(f"成功查询终端用户信息记录: end_user_info_id={end_user_info_id}") return { "success": True, @@ -417,21 +417,21 @@ class UserMemoryService: } except ValueError: - logger.error(f"无效的 user_alias_id 格式: {user_alias_id}") + logger.error(f"无效的 end_user_info_id 格式: {end_user_info_id}") return { "success": False, "data": None, - "error": "无效的用户别名记录ID格式" + "error": "无效的终端用户信息记录ID格式" } except Exception as e: - logger.error(f"查询用户别名记录失败: user_alias_id={user_alias_id}, error={str(e)}") + logger.error(f"查询终端用户信息记录失败: end_user_info_id={end_user_info_id}, error={str(e)}") return { "success": False, "data": None, "error": str(e) } - def create_user_alias( + def create_end_user_info( self, db: Session, end_user_id: str, @@ -440,7 +440,7 @@ class UserMemoryService: meta_data: dict = None ) -> Dict[str, Any]: """ - 创建用户别名记录 + 创建终端用户信息记录 Args: db: 数据库会话 @@ -457,7 +457,7 @@ class UserMemoryService: } """ try: - from app.models.user_alias_model import UserAlias + from app.models.end_user_info_model import EndUserInfo from app.repositories.end_user_repository import EndUserRepository # 转换为UUID并查询用户 @@ -473,30 +473,30 @@ class UserMemoryService: "error": "终端用户不存在" } - # 创建新的别名记录 - new_alias = UserAlias( + # 创建新的用户信息记录 + new_info = EndUserInfo( end_user_id=user_uuid, other_name=other_name, aliases=aliases, meta_data=meta_data ) - db.add(new_alias) + db.add(new_info) db.commit() - db.refresh(new_alias) + db.refresh(new_info) # 构建响应数据 - from app.schemas.user_alias_schema import UserAliasResponse - response_data = UserAliasResponse( - user_alias_id=new_alias.id, - end_user_id=new_alias.end_user_id, - other_name=new_alias.other_name, - aliases=new_alias.aliases, - meta_data=new_alias.meta_data, - created_at=new_alias.created_at, - updated_at=new_alias.updated_at + from app.schemas.end_user_info_schema import EndUserInfoResponse + response_data = EndUserInfoResponse( + end_user_info_id=new_info.id, + end_user_id=new_info.end_user_id, + other_name=new_info.other_name, + aliases=new_info.aliases, + meta_data=new_info.meta_data, + created_at=new_info.created_at, + updated_at=new_info.updated_at ) - logger.info(f"成功创建用户别名记录: end_user_id={end_user_id}") + logger.info(f"成功创建终端用户信息记录: end_user_id={end_user_id}") return { "success": True, @@ -513,25 +513,25 @@ class UserMemoryService: } except Exception as e: db.rollback() - logger.error(f"创建用户别名记录失败: end_user_id={end_user_id}, error={str(e)}") + logger.error(f"创建终端用户信息记录失败: end_user_id={end_user_id}, error={str(e)}") return { "success": False, "data": None, "error": str(e) } - def update_user_alias( + def update_end_user_info( self, db: Session, - user_alias_id: str, + end_user_info_id: str, update_data: Dict[str, Any] ) -> Dict[str, Any]: """ - 更新用户别名记录 + 更新终端用户信息记录 Args: db: 数据库会话 - user_alias_id: 用户别名记录ID (UUID) + end_user_info_id: 终端用户信息记录ID (UUID) update_data: 更新数据字典 Returns: @@ -542,18 +542,18 @@ class UserMemoryService: } """ try: - from app.models.user_alias_model import UserAlias + from app.models.end_user_info_model import EndUserInfo # 转换为UUID并查询 - alias_uuid = uuid.UUID(user_alias_id) - user_alias_record = db.query(UserAlias).filter(UserAlias.id == alias_uuid).first() + info_uuid = uuid.UUID(end_user_info_id) + end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.id == info_uuid).first() - if not user_alias_record: - logger.warning(f"用户别名记录不存在: user_alias_id={user_alias_id}") + if not end_user_info_record: + logger.warning(f"终端用户信息记录不存在: end_user_info_id={end_user_info_id}") return { "success": False, "data": None, - "error": "用户别名记录不存在" + "error": "终端用户信息记录不存在" } # 定义允许更新的字段白名单 @@ -562,28 +562,28 @@ class UserMemoryService: # 更新字段(仅允许白名单中的字段) for field, value in update_data.items(): if field in allowed_fields: - setattr(user_alias_record, field, value) + setattr(end_user_info_record, field, value) # 更新时间戳 - user_alias_record.updated_at = datetime.now() + end_user_info_record.updated_at = datetime.now() # 提交更改 db.commit() - db.refresh(user_alias_record) + db.refresh(end_user_info_record) # 构建响应数据 - from app.schemas.user_alias_schema import UserAliasResponse - response_data = UserAliasResponse( - user_alias_id=user_alias_record.id, - end_user_id=user_alias_record.end_user_id, - other_name=user_alias_record.other_name, - aliases=user_alias_record.aliases, - meta_data=user_alias_record.meta_data, - created_at=user_alias_record.created_at, - updated_at=user_alias_record.updated_at + from app.schemas.end_user_info_schema import EndUserInfoResponse + response_data = EndUserInfoResponse( + end_user_info_id=end_user_info_record.id, + end_user_id=end_user_info_record.end_user_id, + other_name=end_user_info_record.other_name, + aliases=end_user_info_record.aliases, + meta_data=end_user_info_record.meta_data, + created_at=end_user_info_record.created_at, + updated_at=end_user_info_record.updated_at ) - logger.info(f"成功更新用户别名记录: user_alias_id={user_alias_id}, updated_fields={list(update_data.keys())}") + logger.info(f"成功更新终端用户信息记录: end_user_info_id={end_user_info_id}, updated_fields={list(update_data.keys())}") return { "success": True, @@ -592,32 +592,32 @@ class UserMemoryService: } except ValueError: - logger.error(f"无效的 user_alias_id 格式: {user_alias_id}") + logger.error(f"无效的 end_user_info_id 格式: {end_user_info_id}") return { "success": False, "data": None, - "error": "无效的用户别名记录ID格式" + "error": "无效的终端用户信息记录ID格式" } except Exception as e: db.rollback() - logger.error(f"更新用户别名记录失败: user_alias_id={user_alias_id}, error={str(e)}") + logger.error(f"更新终端用户信息记录失败: end_user_info_id={end_user_info_id}, error={str(e)}") return { "success": False, "data": None, "error": str(e) } - def delete_user_alias( + def delete_end_user_info( self, db: Session, - user_alias_id: str + end_user_info_id: str ) -> Dict[str, Any]: """ - 删除用户别名记录 + 删除终端用户信息记录 Args: db: 数据库会话 - user_alias_id: 用户别名记录ID (UUID) + end_user_info_id: 终端用户信息记录ID (UUID) Returns: { @@ -627,42 +627,42 @@ class UserMemoryService: } """ try: - from app.models.user_alias_model import UserAlias + from app.models.end_user_info_model import EndUserInfo # 转换为UUID并查询 - alias_uuid = uuid.UUID(user_alias_id) - user_alias_record = db.query(UserAlias).filter(UserAlias.id == alias_uuid).first() + info_uuid = uuid.UUID(end_user_info_id) + end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.id == info_uuid).first() - if not user_alias_record: - logger.warning(f"用户别名记录不存在: user_alias_id={user_alias_id}") + if not end_user_info_record: + logger.warning(f"终端用户信息记录不存在: end_user_info_id={end_user_info_id}") return { "success": False, "data": None, - "error": "用户别名记录不存在" + "error": "终端用户信息记录不存在" } # 删除记录 - db.delete(user_alias_record) + db.delete(end_user_info_record) db.commit() - logger.info(f"成功删除用户别名记录: user_alias_id={user_alias_id}") + logger.info(f"成功删除终端用户信息记录: end_user_info_id={end_user_info_id}") return { "success": True, - "data": {"user_alias_id": user_alias_id}, + "data": {"end_user_info_id": end_user_info_id}, "error": None } except ValueError: - logger.error(f"无效的 user_alias_id 格式: {user_alias_id}") + logger.error(f"无效的 end_user_info_id 格式: {end_user_info_id}") return { "success": False, "data": None, - "error": "无效的用户别名记录ID格式" + "error": "无效的终端用户信息记录ID格式" } except Exception as e: db.rollback() - logger.error(f"删除用户别名记录失败: user_alias_id={user_alias_id}, error={str(e)}") + logger.error(f"删除终端用户信息记录失败: end_user_info_id={end_user_info_id}, error={str(e)}") return { "success": False, "data": None, From e981f066a384a8408f104aabf187881590cefb2c Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 24 Mar 2026 17:44:54 +0800 Subject: [PATCH 026/120] [changes] Remove the interface and modify the parameters passed in --- .../controllers/user_memory_controllers.py | 119 ++--------- api/app/schemas/end_user_info_schema.py | 2 +- api/app/services/user_memory_service.py | 193 ++---------------- 3 files changed, 37 insertions(+), 277 deletions(-) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index b4c33032..b0dc82a0 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -341,14 +341,14 @@ async def get_community_graph_data_api( @router.get("/end_user_info", response_model=ApiResponse) async def get_end_user_info( - end_user_info_id: str, + end_user_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: """ 查询终端用户信息记录 - 根据 end_user_info_id 查询单条终端用户信息记录。 + 根据 end_user_id 查询单条终端用户信息记录。 """ workspace_id = current_user.current_workspace_id @@ -357,69 +357,27 @@ async def get_end_user_info( return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") api_logger.info( - f"查询终端用户信息请求: end_user_info_id={end_user_info_id}, user={current_user.username}, " + f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - result = user_memory_service.get_end_user_info(db, end_user_info_id) + result = user_memory_service.get_end_user_info(db, end_user_id) if result["success"]: - api_logger.info(f"成功查询终端用户信息: end_user_info_id={end_user_info_id}") + api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}") return success(data=result["data"], msg="查询成功") else: error_msg = result["error"] - api_logger.error(f"查询终端用户信息失败: end_user_info_id={end_user_info_id}, error={error_msg}") + api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}") if error_msg == "终端用户信息记录不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg) - elif error_msg == "无效的终端用户信息记录ID格式": - return fail(BizCode.INVALID_USER_ID, "无效的终端用户信息记录ID格式", error_msg) + elif error_msg == "无效的终端用户ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg) else: return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg) -@router.post("/end_user_info/create", response_model=ApiResponse) -async def create_end_user_info( - info_create: EndUserInfoCreate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -) -> dict: - """ - 创建终端用户信息记录 - - 为指定用户创建一条新的信息记录,支持多个别名。 - """ - workspace_id = current_user.current_workspace_id - end_user_id = info_create.end_user_id - - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试创建终端用户信息但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info( - f"创建终端用户信息请求: end_user_id={end_user_id}, aliases={info_create.aliases}, " - f"user={current_user.username}, workspace={workspace_id}" - ) - - result = user_memory_service.create_end_user_info( - db, end_user_id, info_create.other_name, info_create.aliases, info_create.meta_data - ) - - if result["success"]: - api_logger.info(f"成功创建终端用户信息: end_user_id={end_user_id}") - return success(data=result["data"], msg="创建成功") - else: - error_msg = result["error"] - api_logger.error(f"终端用户信息创建失败: end_user_id={end_user_id}, error={error_msg}") - - if error_msg == "终端用户不存在": - return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) - elif error_msg == "无效的用户ID格式": - return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg) - else: - return fail(BizCode.INTERNAL_ERROR, "终端用户信息创建失败", error_msg) - - @router.post("/end_user_info/updated", response_model=ApiResponse) async def update_end_user_info( info_update: EndUserInfoUpdate, @@ -429,86 +387,47 @@ async def update_end_user_info( """ 更新终端用户信息记录 - 根据 end_user_info_id 更新终端用户信息记录,支持批量更新多个别名。 + 根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名。 示例请求体: { - "end_user_info_id": "2d4f57d4-639b-47aa-937a-d461bc2c2d53", + "end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", "other_name": "张三1", "aliases": ["小张", "张工"], "meta_data": {"position": "工程师", "department": "技术部"} } """ workspace_id = current_user.current_workspace_id - end_user_info_id = info_update.end_user_info_id + end_user_id = info_update.end_user_id if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") api_logger.info( - f"更新终端用户信息请求: end_user_info_id={end_user_info_id}, user={current_user.username}, " + f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - # 获取更新数据(排除 end_user_info_id) - update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_info_id'}) + # 获取更新数据(排除 end_user_id) + update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) - result = user_memory_service.update_end_user_info(db, end_user_info_id, update_data) + result = user_memory_service.update_end_user_info(db, end_user_id, update_data) if result["success"]: - api_logger.info(f"成功更新终端用户信息: end_user_info_id={end_user_info_id}") + api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}") return success(data=result["data"], msg="更新成功") else: error_msg = result["error"] - api_logger.error(f"终端用户信息更新失败: end_user_info_id={end_user_info_id}, error={error_msg}") + api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") if error_msg == "终端用户信息记录不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg) - elif error_msg == "无效的终端用户信息记录ID格式": - return fail(BizCode.INVALID_USER_ID, "无效的终端用户信息记录ID格式", error_msg) + elif error_msg == "无效的终端用户ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg) else: return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg) - -@router.delete("/end_user_info", response_model=ApiResponse) -async def delete_end_user_info( - end_user_info_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -) -> dict: - """ - 删除终端用户信息记录 - - 根据 end_user_info_id 删除指定的终端用户信息记录。 - """ - workspace_id = current_user.current_workspace_id - - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试删除终端用户信息但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info( - f"删除终端用户信息请求: end_user_info_id={end_user_info_id}, user={current_user.username}, " - f"workspace={workspace_id}" - ) - - result = user_memory_service.delete_end_user_info(db, end_user_info_id) - - if result["success"]: - api_logger.info(f"成功删除终端用户信息: end_user_info_id={end_user_info_id}") - return success(data=result["data"], msg="删除成功") - else: - error_msg = result["error"] - api_logger.error(f"终端用户信息删除失败: end_user_info_id={end_user_info_id}, error={error_msg}") - - if error_msg == "终端用户信息记录不存在": - return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg) - elif error_msg == "无效的终端用户信息记录ID格式": - return fail(BizCode.INVALID_USER_ID, "无效的终端用户信息记录ID格式", error_msg) - else: - return fail(BizCode.INTERNAL_ERROR, "终端用户信息删除失败", error_msg) - @router.get("/memory_space/timeline_memories", response_model=ApiResponse) async def memory_space_timeline_of_shared_memories( id: str, label: str, diff --git a/api/app/schemas/end_user_info_schema.py b/api/app/schemas/end_user_info_schema.py index f508190e..60bdf9d6 100644 --- a/api/app/schemas/end_user_info_schema.py +++ b/api/app/schemas/end_user_info_schema.py @@ -19,7 +19,7 @@ class EndUserInfoCreate(EndUserInfoBase): class EndUserInfoUpdate(BaseModel): """更新终端用户信息请求模型""" - end_user_info_id: str = Field(description="终端用户信息记录ID") + end_user_id: str = Field(description="终端用户ID") other_name: Optional[str] = Field(description="用户名称", default=None) aliases: Optional[List[str]] = Field(description="用户别名列表", default=None) meta_data: Optional[Dict[str, Any]] = Field(description="用户相关的扩展信息", default=None) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index ba35c22f..f5f885be 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -365,14 +365,14 @@ class UserMemoryService: def get_end_user_info( self, db: Session, - end_user_info_id: str + end_user_id: str ) -> Dict[str, Any]: """ 查询单个终端用户信息记录 Args: db: 数据库会话 - end_user_info_id: 终端用户信息记录ID (UUID) + end_user_id: 终端用户ID (UUID) Returns: { @@ -385,11 +385,11 @@ class UserMemoryService: from app.models.end_user_info_model import EndUserInfo # 转换为UUID并查询 - info_uuid = uuid.UUID(end_user_info_id) - end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.id == info_uuid).first() + user_uuid = uuid.UUID(end_user_id) + end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.end_user_id == user_uuid).first() if not end_user_info_record: - logger.warning(f"终端用户信息记录不存在: end_user_info_id={end_user_info_id}") + logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, "data": None, @@ -408,95 +408,7 @@ class UserMemoryService: updated_at=end_user_info_record.updated_at ) - logger.info(f"成功查询终端用户信息记录: end_user_info_id={end_user_info_id}") - - return { - "success": True, - "data": response_data.model_dump(), - "error": None - } - - except ValueError: - logger.error(f"无效的 end_user_info_id 格式: {end_user_info_id}") - return { - "success": False, - "data": None, - "error": "无效的终端用户信息记录ID格式" - } - except Exception as e: - logger.error(f"查询终端用户信息记录失败: end_user_info_id={end_user_info_id}, error={str(e)}") - return { - "success": False, - "data": None, - "error": str(e) - } - - def create_end_user_info( - self, - db: Session, - end_user_id: str, - other_name: str, - aliases: List[str] = None, - meta_data: dict = None - ) -> Dict[str, Any]: - """ - 创建终端用户信息记录 - - Args: - db: 数据库会话 - end_user_id: 终端用户ID (UUID) - other_name: 用户名称 - aliases: 别名列表 - meta_data: 扩展信息 - - Returns: - { - "success": bool, - "data": dict, - "error": Optional[str] - } - """ - try: - from app.models.end_user_info_model import EndUserInfo - from app.repositories.end_user_repository import EndUserRepository - - # 转换为UUID并查询用户 - user_uuid = uuid.UUID(end_user_id) - repo = EndUserRepository(db) - end_user = repo.get_by_id(user_uuid) - - if not end_user: - logger.warning(f"终端用户不存在: end_user_id={end_user_id}") - return { - "success": False, - "data": None, - "error": "终端用户不存在" - } - - # 创建新的用户信息记录 - new_info = EndUserInfo( - end_user_id=user_uuid, - other_name=other_name, - aliases=aliases, - meta_data=meta_data - ) - db.add(new_info) - db.commit() - db.refresh(new_info) - - # 构建响应数据 - from app.schemas.end_user_info_schema import EndUserInfoResponse - response_data = EndUserInfoResponse( - end_user_info_id=new_info.id, - end_user_id=new_info.end_user_id, - other_name=new_info.other_name, - aliases=new_info.aliases, - meta_data=new_info.meta_data, - created_at=new_info.created_at, - updated_at=new_info.updated_at - ) - - logger.info(f"成功创建终端用户信息记录: end_user_id={end_user_id}") + logger.info(f"成功查询终端用户信息记录: end_user_id={end_user_id}") return { "success": True, @@ -509,11 +421,10 @@ class UserMemoryService: return { "success": False, "data": None, - "error": "无效的用户ID格式" + "error": "无效的终端用户ID格式" } except Exception as e: - db.rollback() - logger.error(f"创建终端用户信息记录失败: end_user_id={end_user_id}, error={str(e)}") + logger.error(f"查询终端用户信息记录失败: end_user_id={end_user_id}, error={str(e)}") return { "success": False, "data": None, @@ -523,7 +434,7 @@ class UserMemoryService: def update_end_user_info( self, db: Session, - end_user_info_id: str, + end_user_id: str, update_data: Dict[str, Any] ) -> Dict[str, Any]: """ @@ -531,7 +442,7 @@ class UserMemoryService: Args: db: 数据库会话 - end_user_info_id: 终端用户信息记录ID (UUID) + end_user_id: 终端用户ID (UUID) update_data: 更新数据字典 Returns: @@ -545,11 +456,11 @@ class UserMemoryService: from app.models.end_user_info_model import EndUserInfo # 转换为UUID并查询 - info_uuid = uuid.UUID(end_user_info_id) - end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.id == info_uuid).first() + user_uuid = uuid.UUID(end_user_id) + end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.end_user_id == user_uuid).first() if not end_user_info_record: - logger.warning(f"终端用户信息记录不存在: end_user_info_id={end_user_info_id}") + logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, "data": None, @@ -583,7 +494,7 @@ class UserMemoryService: updated_at=end_user_info_record.updated_at ) - logger.info(f"成功更新终端用户信息记录: end_user_info_id={end_user_info_id}, updated_fields={list(update_data.keys())}") + logger.info(f"成功更新终端用户信息记录: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}") return { "success": True, @@ -592,85 +503,15 @@ class UserMemoryService: } except ValueError: - logger.error(f"无效的 end_user_info_id 格式: {end_user_info_id}") + logger.error(f"无效的 end_user_id 格式: {end_user_id}") return { "success": False, "data": None, - "error": "无效的终端用户信息记录ID格式" + "error": "无效的终端用户ID格式" } except Exception as e: db.rollback() - logger.error(f"更新终端用户信息记录失败: end_user_info_id={end_user_info_id}, error={str(e)}") - return { - "success": False, - "data": None, - "error": str(e) - } - - def delete_end_user_info( - self, - db: Session, - end_user_info_id: str - ) -> Dict[str, Any]: - """ - 删除终端用户信息记录 - - Args: - db: 数据库会话 - end_user_info_id: 终端用户信息记录ID (UUID) - - Returns: - { - "success": bool, - "data": dict, - "error": Optional[str] - } - """ - try: - from app.models.end_user_info_model import EndUserInfo - - # 转换为UUID并查询 - info_uuid = uuid.UUID(end_user_info_id) - end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.id == info_uuid).first() - - if not end_user_info_record: - logger.warning(f"终端用户信息记录不存在: end_user_info_id={end_user_info_id}") - return { - "success": False, - "data": None, - "error": "终端用户信息记录不存在" - } - - # 删除记录 - db.delete(end_user_info_record) - db.commit() - - logger.info(f"成功删除终端用户信息记录: end_user_info_id={end_user_info_id}") - - return { - "success": True, - "data": {"end_user_info_id": end_user_info_id}, - "error": None - } - - except ValueError: - logger.error(f"无效的 end_user_info_id 格式: {end_user_info_id}") - return { - "success": False, - "data": None, - "error": "无效的终端用户信息记录ID格式" - } - except Exception as e: - db.rollback() - logger.error(f"删除终端用户信息记录失败: end_user_info_id={end_user_info_id}, error={str(e)}") - return { - "success": False, - "data": None, - "error": str(e) - } - except Exception as e: - db.rollback() - logger.error(f"用户别名记录更新失败: user_alias_id={user_alias_id}, error={str(e)}") + logger.error(f"更新终端用户信息记录失败: end_user_id={end_user_id}, error={str(e)}") return { "success": False, "data": None, From 7c0743eb8f9b3446428fc678807afdc386b28904 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 24 Mar 2026 18:38:55 +0800 Subject: [PATCH 027/120] [changes] Modify to a millisecond-level timestamp --- api/app/services/user_memory_service.py | 48 ++++++++++++------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index f5f885be..ec3a7363 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -383,6 +383,7 @@ class UserMemoryService: """ try: from app.models.end_user_info_model import EndUserInfo + from app.core.api_key_utils import datetime_to_timestamp # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) @@ -396,23 +397,22 @@ class UserMemoryService: "error": "终端用户信息记录不存在" } - # 构建响应数据 - from app.schemas.end_user_info_schema import EndUserInfoResponse - response_data = EndUserInfoResponse( - end_user_info_id=end_user_info_record.id, - end_user_id=end_user_info_record.end_user_id, - other_name=end_user_info_record.other_name, - aliases=end_user_info_record.aliases, - meta_data=end_user_info_record.meta_data, - created_at=end_user_info_record.created_at, - updated_at=end_user_info_record.updated_at - ) + # 构建响应数据(转换时间为毫秒时间戳) + response_data = { + "end_user_info_id": str(end_user_info_record.id), + "end_user_id": str(end_user_info_record.end_user_id), + "other_name": end_user_info_record.other_name, + "aliases": end_user_info_record.aliases, + "meta_data": end_user_info_record.meta_data, + "created_at": datetime_to_timestamp(end_user_info_record.created_at), + "updated_at": datetime_to_timestamp(end_user_info_record.updated_at) + } logger.info(f"成功查询终端用户信息记录: end_user_id={end_user_id}") return { "success": True, - "data": response_data.model_dump(), + "data": response_data, "error": None } @@ -454,6 +454,7 @@ class UserMemoryService: """ try: from app.models.end_user_info_model import EndUserInfo + from app.core.api_key_utils import datetime_to_timestamp # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) @@ -482,23 +483,22 @@ class UserMemoryService: db.commit() db.refresh(end_user_info_record) - # 构建响应数据 - from app.schemas.end_user_info_schema import EndUserInfoResponse - response_data = EndUserInfoResponse( - end_user_info_id=end_user_info_record.id, - end_user_id=end_user_info_record.end_user_id, - other_name=end_user_info_record.other_name, - aliases=end_user_info_record.aliases, - meta_data=end_user_info_record.meta_data, - created_at=end_user_info_record.created_at, - updated_at=end_user_info_record.updated_at - ) + # 构建响应数据(转换时间为毫秒时间戳) + response_data = { + "end_user_info_id": str(end_user_info_record.id), + "end_user_id": str(end_user_info_record.end_user_id), + "other_name": end_user_info_record.other_name, + "aliases": end_user_info_record.aliases, + "meta_data": end_user_info_record.meta_data, + "created_at": datetime_to_timestamp(end_user_info_record.created_at), + "updated_at": datetime_to_timestamp(end_user_info_record.updated_at) + } logger.info(f"成功更新终端用户信息记录: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}") return { "success": True, - "data": response_data.model_dump(), + "data": response_data, "error": None } From 38c6c7f053bf69d17a8607785f0acdcb762f3271 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 25 Mar 2026 16:26:41 +0800 Subject: [PATCH 028/120] [changes] Simultaneously create the "end_user_info" data to ensure that the interface modification takes effect immediately. --- .../core/memory/agent/utils/write_tools.py | 16 + .../extraction_orchestrator.py | 151 ++++++++++ .../prompt/prompts/extract_triplet.jinja2 | 282 ++++++++++++------ api/app/repositories/end_user_repository.py | 19 +- api/app/services/user_memory_service.py | 63 ++++ 5 files changed, 438 insertions(+), 93 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 6829cf57..5829a5c9 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -176,6 +176,22 @@ async def write( ) if success: logger.info("Successfully saved all data to Neo4j") + + # 同步用户别名到 PostgreSQL + try: + # 创建一个临时的 orchestrator 实例来调用同步方法 + temp_orchestrator = ExtractionOrchestrator( + llm_client=llm_client, + embedder_client=embedder_client, + connector=neo4j_connector, + embedding_id=embedding_model_id + ) + await temp_orchestrator._update_end_user_other_name(all_entity_nodes, chunked_dialogs) + logger.info("Successfully synced user aliases to PostgreSQL") + except Exception as sync_error: + logger.error(f"Failed to sync user aliases to PostgreSQL: {sync_error}", exc_info=True) + # 不影响主流程 + # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突) await _trigger_clustering_sync( all_entity_nodes, diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index da10c497..d5681da9 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -19,6 +19,7 @@ import asyncio import logging import os +import uuid from datetime import datetime from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple @@ -62,6 +63,10 @@ from app.core.memory.storage_services.extraction_engine.pipeline_help import ( export_test_input_doc, ) from app.core.memory.utils.data.ontology import TemporalInfo +from app.db import get_db_context +from app.models.end_user_info_model import EndUserInfo +from app.repositories.end_user_info_repository import EndUserInfoRepository +from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 配置日志 @@ -1325,6 +1330,152 @@ class ExtractionOrchestrator: perceptual_edges ) + async def _update_end_user_other_name( + self, + entity_nodes: List[ExtractedEntityNode], + dialog_data_list: List[DialogData] + ) -> None: + """ + 从 Neo4j 读取用户实体的最终 aliases,同步到 end_user 和 end_user_info 表 + + 注意: + 1. other_name 使用本次对话提取的第一个别名(保持时间顺序) + 2. aliases 从 Neo4j 读取(保持完整性) + + Args: + entity_nodes: 实体节点列表 + dialog_data_list: 对话数据列表 + """ + try: + if not dialog_data_list: + logger.warning("dialog_data_list 为空,跳过用户别名同步") + return + + end_user_id = dialog_data_list[0].end_user_id + if not end_user_id: + logger.warning("end_user_id 为空,跳过用户别名同步") + return + + # 1. 提取本次对话的用户别名(保持 LLM 提取的原始顺序,不排序) + current_aliases = self._extract_current_aliases(entity_nodes) + + # 2. 从 Neo4j 获取完整 aliases(权威数据源) + neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id) + + if not neo4j_aliases: + # Neo4j 中没有别名,使用本次对话提取的别名 + neo4j_aliases = current_aliases + if not neo4j_aliases: + logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}") + return + + logger.info(f"本次对话提取的 aliases: {current_aliases}") + logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}") + + # 3. 同步到数据库 + end_user_uuid = uuid.UUID(end_user_id) + with get_db_context() as db: + # 更新 end_user 表 + end_user = EndUserRepository(db).get_by_id(end_user_uuid) + if not end_user: + logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录") + return + + new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases) + if new_name is not None: + end_user.other_name = new_name + logger.info(f"更新 end_user 表 other_name → {new_name}") + else: + logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}") + + # 更新或创建 end_user_info 记录 + existing_infos = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + if existing_infos: + info = existing_infos[0] + new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases) + if new_name_info is not None: + info.other_name = new_name_info + logger.info(f"更新 end_user_info 表 other_name → {new_name_info}") + if info.aliases != neo4j_aliases: + info.aliases = neo4j_aliases + logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}") + else: + first_alias = current_aliases[0].strip() if current_aliases else "" + if first_alias: + db.add(EndUserInfo( + end_user_id=end_user_uuid, + other_name=first_alias, + aliases=neo4j_aliases, + meta_data={} + )) + logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={neo4j_aliases}") + + db.commit() + + except Exception as e: + logger.error(f"更新 end_user other_name 失败: {e}", exc_info=True) + + + + def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: + """从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序) + + 这个方法直接返回 LLM 提取的别名列表,不做任何修改。 + 第一个别名将被用作 other_name。 + + Args: + entity_nodes: 实体节点列表 + + Returns: + 别名列表(保持 LLM 提取的原始顺序) + """ + USER_NAMES = {'用户', '我', 'User', 'I'} + for entity in entity_nodes: + if getattr(entity, 'name', '').strip() in USER_NAMES: + aliases = getattr(entity, 'aliases', []) or [] + logger.debug(f"提取到用户别名(原始顺序): {aliases}") + return aliases + return [] + + + async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]: + """从 Neo4j 查询用户实体的完整 aliases 列表""" + cypher = """ + MATCH (e:ExtractedEntity) + WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I'] + RETURN e.aliases AS aliases + LIMIT 1 + """ + result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id) + if not result: + logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}") + return [] + aliases = result[0].get('aliases') or [] + if not aliases: + logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}") + return aliases + + def _resolve_other_name( + self, + current: Optional[str], + current_aliases: List[str], + neo4j_aliases: List[str] + ) -> Optional[str]: + """ + 决定 other_name 是否需要更新,返回新值;无需更新返回 None。 + + 决策规则: + - 为空 → 用本次对话第一个别名 + - 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除) + - 否则 → 保持不变(返回 None) + """ + if not current or not current.strip(): + return current_aliases[0].strip() if current_aliases else None + if current not in neo4j_aliases: + return neo4j_aliases[0].strip() if neo4j_aliases else None + + return None + async def _run_dedup_and_write_summary( self, dialogue_nodes: List[DialogueNode], diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index 25fffa33..09e6ff8d 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -5,6 +5,15 @@ ===Task=== Extract entities and knowledge triplets from the given statement. +**⚠️ CRITICAL REQUIREMENTS:** +1. **ALIASES ORDER IS CRITICAL**: The FIRST alias in the array will be used as the user's primary display name (other_name). You MUST put the most important/frequently used name FIRST. +2. **ALWAYS include aliases field**: Even if empty, you MUST include "aliases": [] in EVERY entity. + + + {% if language == "zh" %} **重要:请使用中文生成实体名称(name)、描述(description)和示例(example)。** {% else %} @@ -18,34 +27,29 @@ Extract entities and knowledge triplets from the given statement. {% if ontology_types %} ===Ontology Type Guidance=== -**CRITICAL RULE: You MUST ONLY use the predefined ontology type names listed below for the entity "type" field. Do NOT use any other type names, even if they seem reasonable.** +**CRITICAL: Use ONLY predefined type names below. If no exact match, use CLOSEST type. NEVER invent new types.** -**If no predefined type fits an entity, use the CLOSEST matching predefined type. NEVER invent new type names.** +**Type Priority:** +1. [场景类型] Scene Types (domain-specific, prefer first) +2. [通用类型] General Types (standard ontologies) +3. [通用父类] Parent Types (hierarchy context) -**Type Priority (from highest to lowest):** -1. **[场景类型] Scene Types** - Domain-specific types, ALWAYS prefer these first -2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia) -3. **[通用父类] Parent Types** - Provide type hierarchy context +**Rules:** +- Type MUST exactly match predefined names +- Do NOT modify, translate, or abbreviate type names +- Prefer scene types over general types -**Type Matching Rules:** -- Entity type MUST exactly match one of the predefined type names below -- Do NOT use types like "Equipment", "Component", "Concept", "Action", "Condition", "Data", "Duration" unless they appear in the predefined list -- Do NOT modify, translate, abbreviate, or create variations of type names -- Prefer scene types (marked [场景类型]) over general types when both could apply -- If uncertain, check the type description to find the best match - -**Predefined Ontology Types:** +**Predefined Types:** {{ ontology_types }} {% if type_hierarchy_hints %} -**Type Hierarchy Reference:** -The following shows type inheritance relationships (Child → Parent → Grandparent): +**Hierarchy:** {% for hint in type_hierarchy_hints %} - {{ hint }} {% endfor %} {% endif %} -**ALLOWED Type Names (use EXACTLY one of these, no exceptions):** +**ALLOWED Names:** {{ ontology_type_names | join(', ') }} {% endif %} @@ -62,75 +66,114 @@ The following shows type inheritance relationships (Child → Parent → Grandpa - **Entity descriptions must be in English** - **Examples must be in English** {% endif %} -- **Semantic Memory Classification (is_explicit_memory):** - * Set to `true` if the entity represents **explicit/semantic memory**: - - **Concepts:** "Machine Learning", "Photosynthesis", "Democracy" - - **Knowledge:** "Python Programming Language", "Theory of Relativity" - - **Definitions:** "API (Application Programming Interface)", "REST API" - - **Principles:** "SOLID Principles", "First Law of Thermodynamics" - - **Theories:** "Evolution Theory", "Quantum Mechanics" - - **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm" - - **Technical Terms:** "Neural Network", "Database" - * Set to `false` for: - - **People:** "John Smith", "Dr. Wang" - - **Organizations:** "Microsoft", "Harvard University" - - **Locations:** "Beijing", "Central Park" - - **Events:** "2024 Conference", "Project Meeting" - - **Specific objects:** "iPhone 15", "Building A" -- **Example Generation (IMPORTANT for semantic memory entities):** - * For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept - * The example should be: - - **Specific and concrete**: Use real-world scenarios or applications - - **Brief**: Around 20 characters (can be slightly longer if needed for clarity) -{% if language == "zh" %} - - **使用中文** -{% else %} - - **In English** -{% endif %} - * For non-semantic entities (`is_explicit_memory=false`), the example field can be empty -- **Aliases Extraction:** -{% if language == "zh" %} - * 别名使用中文 -{% else %} - * Aliases should be in English -{% endif %} - * Include common alternative names, abbreviations and full names - * If no aliases exist, use empty array: [] +- **Semantic Memory (is_explicit_memory):** + * `true` for: Concepts, Knowledge, Definitions, Theories, Methods (e.g., "Machine Learning", "REST API") + * `false` for: People, Organizations, Locations, Events, Specific objects + * For `is_explicit_memory=true`, provide concise example (~20 chars{% if language == "zh" %},使用中文{% endif %}) - **姓名别名识别规则(Name Alias Recognition):** - * 当前对话的用户实体 name 固定为"用户",不得使用用户透露的真实姓名作为 name - * 自我称呼模式:用户说"我的名字是X"、"我叫X" → X 加入 aliases(name 保持为"用户") - * 昵称/小名模式:识别"小名"、"昵称"、"英文名"、"网名"等关键词后的称呼 → 加入 aliases - * 他人称呼模式:识别"同事叫我X"、"朋友叫我X"、"大家叫我X" → X 加入 aliases - * 同一实体的多个称呼应合并到同一 Entity 的 aliases 列表中 - * aliases 中不应包含与 name 完全相同的字符串 - * **严禁将已加入某实体 aliases 的词再单独抽取为另一个独立实体**:若某个词已作为别名归属于"用户"实体,则不得再将该词作为独立 Entity 的 name 出现在 entities 列表中 -- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions -- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value) - Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric' +**🚨🚨🚨 ALIASES & DENIED_ALIASES - MANDATORY FIELDS 🚨🚨🚨** + +**CRITICAL RULES (违反将导致提取失败):** + +1. **EVERY entity MUST have BOTH fields:** + - `"aliases": [...]` - REQUIRED, even if empty `[]` + - `"denied_aliases": [...]` - REQUIRED, even if empty `[]` + +2. **ALIASES - 别名提取规则:** +{% if language == "zh" %} + - 包含:昵称、全名、简称、别称、网名等 + - 顺序:**第一个别名将作为用户的主显示名称(other_name),必须把最重要/最常用的名字放在第一位** + - 提取顺序:严格按照对话中首次出现的顺序 + - 示例: + * "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name) + * "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name) + - 空值:如果没有别名,使用 `[]` + - 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字 +{% else %} + - Include: nicknames, full names, abbreviations, alternative names + - Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST** + - Extraction order: Strictly follow the order of first appearance in conversation + - Examples: + * "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name) + * "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name) + - Empty: If no aliases, use `[]` + - Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names +{% endif %} + + + +4. **USER ENTITY SPECIAL HANDLING:** +{% if language == "zh" %} + - 用户实体的 name 字段:使用 "用户" 或 "我" + - 用户的真实姓名:放入 aliases + - 示例: + * "我叫李明" → name="用户", aliases=["李明"] +{% else %} + - User entity name field: use "User" or "I" + - User's real name: put in aliases + - Examples: + * "I'm John" → name="User", aliases=["John"] +{% endif %} + + + +5. **CONFLICT RESOLUTION:** +{% if language == "zh" %} + - 顺序优先级:按出现顺序,先出现的在前 +{% else %} + - Order priority: by appearance order, first mentioned comes first +{% endif %} + + + +**EXAMPLES OF CORRECT EXTRACTION:** +{% if language == "zh" %} +- "我叫张三" → aliases=["张三"] (张三将成为 other_name) +- "大家叫我小明,我全名叫李明" → aliases=["小明", "李明"] (小明先出现,将成为 other_name) +- "我是李华,网名叫华仔" → aliases=["李华", "华仔"] (李华先出现,将成为 other_name) + + +{% else %} +- "I'm John" → aliases=["John"] (John will become other_name) +- "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name) +- "I'm John Smith, username JSmith" → aliases=["John Smith", "JSmith"] (John Smith appears first, will become other_name) + + +{% endif %} + +- Exclude lengthy quotes, dates, temporal expressions +- Numeric values: extract as entities (instance_of: 'Numeric', name: units, numeric_value: value) **Triplet Extraction:** -- Extract (subject, predicate, object) triplets where: - - Subject: main entity performing the action or being described - - Predicate: relationship between entities (e.g., 'is', 'works at', 'believes') - - Object: entity, value, or concept affected by the predicate +- Extract (subject, predicate, object) where subject/object are entities, predicate is relationship {% if language == "zh" %} -- subject_name 和 object_name 必须使用中文 +- subject_name 和 object_name 使用中文 {% else %} -- subject_name and object_name must be in English (translate if original is in another language) +- subject_name and object_name in English {% endif %} -- Exclude all temporal expressions from every field -- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens) -- Do NOT translate predicate tokens -- Do NOT include `statement_id` field (assigned automatically) - -**When NOT to extract triplets:** -- Non-propositional utterances (emotions, fillers, onomatopoeia) -- No clear predicate from the given definitions applies -- Standalone noun phrases or checklist items → extract as entities only -- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS") - -**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty. +- Use ONLY predicates from "Predicate Instructions" (uppercase tokens) +- Exclude temporal expressions, do NOT include `statement_id` +- **When NOT to extract:** emotions, fillers, no clear predicate, standalone nouns +- **If no valid triplet:** Return triplets: [] {%- if predicate_instructions -%} **Predicate Instructions:** @@ -217,34 +260,91 @@ Output: ] } -**Example 4 (姓名别名识别 - Chinese):** "我的名字是乐力齐,我的小名是齐齐,同事们都叫我小乐" +**Example 4 (别名 - Chinese):** "我的名字是乐力齐,我的小名是齐齐,同事们都叫我小乐" Output: { "triplets": [], "entities": [ - {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人,有多个称呼", "example": "", "aliases": ["乐力齐", "齐齐", "小乐"], "is_explicit_memory": false} + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["乐力齐", "齐齐", "小乐"], "is_explicit_memory": false} + ] +} + +**Example 5 (别名顺序 - Chinese):** "我叫陈思远。对了,我的网名叫「远山」" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远", "远山"], "is_explicit_memory": false} + ] +} + +**Example 6 (否定别名 - Chinese):** "我不叫陈思远,我其实叫小小张" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["小小张"], "denied_aliases": ["陈思远"], "is_explicit_memory": false} + ] +} + +**Example 7 (否定别名 - Chinese):** "我不叫远山" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "denied_aliases": ["远山"], "is_explicit_memory": false} + ] +} + +**Example 8 (复杂场景 - Chinese):** "大家都叫我明明,我的全名是小明,但我不是小红" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["明明", "小明"], "denied_aliases": ["小红"], "is_explicit_memory": false} + ] +} + +**Example 9 (纠正错误 - Chinese):** "我搞错了,我的网名不叫做远山,网名叫做大山" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["大山"], "denied_aliases": ["远山"], "is_explicit_memory": false} + ] +} + +**Example 10 (多重纠正 - Chinese):** "其实我不是老张,也不叫小张,我叫张三" +Output: +{ + "triplets": [], + "entities": [ + {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["张三"], "denied_aliases": ["老张", "小张"], "is_explicit_memory": false} ] } {% endif %} ===End of Examples=== {% if ontology_types %} -**⚠️ REMINDER: The examples above use generic type names for illustration only. You MUST use ONLY the predefined ontology type names from the "ALLOWED Type Names" list above. For example, use "PredictiveMaintenance" instead of "Concept", use "ProductionLine" instead of "Equipment", etc. Map each entity to the closest matching predefined type.** +**⚠️ REMINDER: Examples use generic types for illustration. You MUST use predefined types from "ALLOWED Names" above.** {% endif %} ===Output Format=== **JSON Requirements:** -- Use only ASCII double quotes (") for JSON structure -- Never use Chinese quotation marks ("") or Unicode quotes -- Escape quotation marks in text with backslashes (\") -- Ensure proper string closure and comma separation -- No line breaks within JSON string values +- Use ASCII double quotes ("), escape with \" +- No Chinese quotes (""), no line breaks in strings {% if language == "zh" %} -- **语言要求:实体名称(name)、描述(description)、示例(example)、subject_name、object_name 必须使用中文** +- **语言:name、description、example、subject_name、object_name 使用中文** {% else %} -- **Language Requirement: Entity names, descriptions, examples, subject_name, object_name must be in English** -- **If the original text is in Chinese, translate all names to English** +- **Language: names, descriptions, examples in English (translate if needed)** {% endif %} +- **⚠️ ALIASES ORDER: preserve temporal order of appearance** +- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []** + + {{ json_schema }} diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 71c93634..8ac65cd7 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger from app.models.app_model import App from app.models.end_user_model import EndUser +from app.models.end_user_info_model import EndUserInfo from app.models.workspace_model import Workspace # 获取数据库专用日志器 @@ -70,7 +71,8 @@ class EndUserRepository: app_id: uuid.UUID, workspace_id: uuid.UUID, other_id: str, - original_user_id: Optional[str] = None + original_user_id: Optional[str] = None, + other_name: Optional[str] = None ) -> EndUser: """获取或创建终端用户 @@ -79,6 +81,7 @@ class EndUserRepository: workspace_id: 工作空间ID other_id: 第三方ID original_user_id: 原始用户ID (存储到 other_id) + other_name: 用户名称(用于创建 EndUserInfo) """ try: # 尝试查找现有用户 @@ -106,10 +109,22 @@ class EndUserRepository: other_id=other_id ) self.db.add(end_user) + self.db.flush() # 刷新以获取 end_user.id,但不提交事务 + + # 创建对应的 EndUserInfo 记录 + end_user_info = EndUserInfo( + end_user_id=end_user.id, + other_name=other_name or "", # 如果没有提供 other_name,使用空字符串 + aliases={}, # 空字典而不是 None + meta_data={} # 空字典而不是 None + ) + self.db.add(end_user_info) + + # 一起提交 self.db.commit() self.db.refresh(end_user) - db_logger.info(f"创建新终端用户: (other_id: {other_id}) for workspace {workspace_id}") + db_logger.info(f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}") return end_user except Exception as e: diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index ec3a7363..4106a1b0 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -454,6 +454,7 @@ class UserMemoryService: """ try: from app.models.end_user_info_model import EndUserInfo + from app.models.end_user_model import EndUser from app.core.api_key_utils import datetime_to_timestamp # 转换为UUID并查询 @@ -471,6 +472,12 @@ class UserMemoryService: # 定义允许更新的字段白名单 allowed_fields = {'other_name', 'aliases', 'meta_data'} + # 检查是否更新了 aliases 字段 + aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases + + # 检查是否更新了 other_name 字段 + other_name_updated = 'other_name' in update_data and update_data['other_name'] != end_user_info_record.other_name + # 更新字段(仅允许白名单中的字段) for field, value in update_data.items(): if field in allowed_fields: @@ -479,10 +486,30 @@ class UserMemoryService: # 更新时间戳 end_user_info_record.updated_at = datetime.now() + # 如果 other_name 被更新,同步更新 end_user 表 + if other_name_updated: + end_user_record = db.query(EndUser).filter(EndUser.id == user_uuid).first() + if end_user_record: + end_user_record.other_name = update_data['other_name'] + end_user_record.updated_at = datetime.now() + logger.info(f"同步更新 end_user 表的 other_name: end_user_id={end_user_id}, other_name={update_data['other_name']}") + else: + logger.warning(f"未找到对应的 end_user 记录: end_user_id={end_user_id}") + # 提交更改 db.commit() db.refresh(end_user_info_record) + # 如果 aliases 被更新,同步到 Neo4j + if aliases_updated: + try: + import asyncio + asyncio.create_task(self._sync_aliases_to_neo4j(end_user_id, update_data['aliases'])) + logger.info(f"已触发 aliases 同步到 Neo4j: end_user_id={end_user_id}, aliases={update_data['aliases']}") + except Exception as sync_error: + logger.error(f"触发同步 aliases 到 Neo4j 失败: {sync_error}", exc_info=True) + # 不影响主流程,只记录错误 + # 构建响应数据(转换时间为毫秒时间戳) response_data = { "end_user_info_id": str(end_user_info_record.id), @@ -518,6 +545,42 @@ class UserMemoryService: "error": str(e) } + async def _sync_aliases_to_neo4j(self, end_user_id: str, aliases: List[str]) -> None: + """ + 将 aliases 同步到 Neo4j 中的用户实体 + + Args: + end_user_id: 终端用户ID + aliases: 别名列表 + """ + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + # Cypher 查询:更新用户实体的 aliases + cypher_query = """ + MATCH (e:ExtractedEntity) + WHERE e.end_user_id = $end_user_id + AND e.name IN ['用户', '我', 'User', 'I'] + SET e.aliases = $aliases + RETURN e.id AS entity_id, e.name AS entity_name, e.aliases AS updated_aliases + """ + + connector = Neo4jConnector() + try: + result = await connector.execute_query( + cypher_query, + end_user_id=end_user_id, + aliases=aliases + ) + + if result: + logger.info(f"成功同步 aliases 到 Neo4j: end_user_id={end_user_id}, 更新了 {len(result)} 个实体节点") + else: + logger.warning(f"未找到需要更新的用户实体节点: end_user_id={end_user_id}") + + except Exception as e: + logger.error(f"同步 aliases 到 Neo4j 失败: {e}", exc_info=True) + raise + async def get_cached_memory_insight( self, db: Session, From 1e986c641f2337dc85e99944c11808bd0dcde61f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 25 Mar 2026 16:36:20 +0800 Subject: [PATCH 029/120] [fix] Fix the code according to the comments --- api/app/models/__init__.py | 2 +- api/app/models/end_user_info_model.py | 4 +-- .../repositories/end_user_info_repository.py | 32 ++++--------------- api/app/services/user_memory_service.py | 22 ++++++++----- 4 files changed, 23 insertions(+), 37 deletions(-) diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 7dd26d34..e889504a 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -61,7 +61,7 @@ __all__ = [ "AppRelease", "MemoryIncrement", "EndUser", - "UserAlias", + "EndUserInfo", "AppShare", "ReleaseShare", "Conversation", diff --git a/api/app/models/end_user_info_model.py b/api/app/models/end_user_info_model.py index ed747002..c02f254c 100644 --- a/api/app/models/end_user_info_model.py +++ b/api/app/models/end_user_info_model.py @@ -1,7 +1,7 @@ import datetime import uuid -from sqlalchemy import Column, DateTime, ForeignKey, String, Text +from sqlalchemy import Column, DateTime, ForeignKey, String, Text, ARRAY from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy.orm import relationship @@ -15,7 +15,7 @@ class EndUserInfo(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), nullable=False, index=True, comment="关联的终端用户ID") other_name = Column(String, nullable=False, comment="关联的用户名称") - aliases = Column(JSONB, nullable=True, comment="用户别名列表(JSON数组)") + aliases = Column(ARRAY(String), nullable=True, comment="用户别名列表(字符串数组)") meta_data = Column(JSONB, nullable=True, comment="用户相关的扩展信息(JSON格式)") created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") diff --git a/api/app/repositories/end_user_info_repository.py b/api/app/repositories/end_user_info_repository.py index ee05d12d..f9f4665c 100644 --- a/api/app/repositories/end_user_info_repository.py +++ b/api/app/repositories/end_user_info_repository.py @@ -17,18 +17,18 @@ class EndUserInfoRepository: def __init__(self, db: Session): self.db = db - def create(self, end_user_id: uuid.UUID, other_name: str, alias: str = None, meta_data: dict = None) -> EndUserInfo: + def create(self, end_user_id: uuid.UUID, other_name: str, aliases: List[str] = None, meta_data: dict = None) -> EndUserInfo: """创建终端用户信息""" end_user_info = EndUserInfo( end_user_id=end_user_id, other_name=other_name, - alias=alias, + aliases=aliases or [], meta_data=meta_data ) self.db.add(end_user_info) self.db.commit() self.db.refresh(end_user_info) - logger.info(f"创建终端用户信息: end_user_id={end_user_id}, alias={alias}") + logger.info(f"创建终端用户信息: end_user_id={end_user_id}, aliases={aliases}") return end_user_info def get_by_id(self, info_id: uuid.UUID) -> Optional[EndUserInfo]: @@ -39,12 +39,12 @@ class EndUserInfoRepository: """获取用户的所有信息记录""" return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).all() - def update(self, info_id: uuid.UUID, alias: str = None, meta_data: dict = None) -> Optional[EndUserInfo]: + def update(self, info_id: uuid.UUID, aliases: List[str] = None, meta_data: dict = None) -> Optional[EndUserInfo]: """更新用户信息""" end_user_info = self.get_by_id(info_id) if end_user_info: - if alias is not None: - end_user_info.alias = alias + if aliases is not None: + end_user_info.aliases = aliases if meta_data is not None: end_user_info.meta_data = meta_data self.db.commit() @@ -68,23 +68,3 @@ class EndUserInfoRepository: self.db.commit() logger.info(f"删除用户所有信息记录: end_user_id={end_user_id}, count={count}") return count - - def batch_create(self, end_user_id: uuid.UUID, other_name: str, aliases: List[str]) -> List[EndUserInfo]: - """批量创建用户信息""" - end_user_infos = [] - for alias in aliases: - if alias and alias.strip(): - end_user_info = EndUserInfo( - end_user_id=end_user_id, - other_name=other_name, - alias=alias.strip() - ) - self.db.add(end_user_info) - end_user_infos.append(end_user_info) - - self.db.commit() - for end_user_info in end_user_infos: - self.db.refresh(end_user_info) - - logger.info(f"批量创建终端用户信息: end_user_id={end_user_id}, count={len(end_user_infos)}") - return end_user_infos diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 4106a1b0..28dabe10 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -382,14 +382,14 @@ class UserMemoryService: } """ try: - from app.models.end_user_info_model import EndUserInfo + from app.repositories.end_user_info_repository import EndUserInfoRepository from app.core.api_key_utils import datetime_to_timestamp # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) - end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.end_user_id == user_uuid).first() + end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) - if not end_user_info_record: + if not end_user_info_records: logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, @@ -397,6 +397,9 @@ class UserMemoryService: "error": "终端用户信息记录不存在" } + # 获取第一条记录 + end_user_info_record = end_user_info_records[0] + # 构建响应数据(转换时间为毫秒时间戳) response_data = { "end_user_info_id": str(end_user_info_record.id), @@ -453,15 +456,15 @@ class UserMemoryService: } """ try: - from app.models.end_user_info_model import EndUserInfo - from app.models.end_user_model import EndUser + from app.repositories.end_user_info_repository import EndUserInfoRepository + from app.repositories.end_user_repository import EndUserRepository from app.core.api_key_utils import datetime_to_timestamp # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) - end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.end_user_id == user_uuid).first() + end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) - if not end_user_info_record: + if not end_user_info_records: logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, @@ -469,6 +472,9 @@ class UserMemoryService: "error": "终端用户信息记录不存在" } + # 获取第一条记录 + end_user_info_record = end_user_info_records[0] + # 定义允许更新的字段白名单 allowed_fields = {'other_name', 'aliases', 'meta_data'} @@ -488,7 +494,7 @@ class UserMemoryService: # 如果 other_name 被更新,同步更新 end_user 表 if other_name_updated: - end_user_record = db.query(EndUser).filter(EndUser.id == user_uuid).first() + end_user_record = EndUserRepository(db).get_by_id(user_uuid) if end_user_record: end_user_record.other_name = update_data['other_name'] end_user_record.updated_at = datetime.now() From 2a12cb04bfa7c83f394b5bb272dcab0cd8e77ffc Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 25 Mar 2026 17:04:53 +0800 Subject: [PATCH 030/120] [changes] Optimize the Cypher query statement --- api/app/repositories/neo4j/cypher_queries.py | 39 +++++++------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index f80b7e26..c08f9d0e 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -339,6 +339,19 @@ LIMIT $limit SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """ CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +WITH e, score +UNION +MATCH (e:ExtractedEntity) +WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) + AND e.aliases IS NOT NULL + AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q)) +WITH e, + CASE + WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 + WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 + ELSE 0.8 + END AS score +WITH DISTINCT e, MAX(score) AS score OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) RETURN e.id AS id, @@ -360,32 +373,6 @@ RETURN e.id AS id, e.last_access_time AS last_access_time, COALESCE(e.access_count, 0) AS access_count, score -UNION -MATCH (e:ExtractedEntity) -WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) - AND e.aliases IS NOT NULL - AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q)) -OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) -OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) -RETURN e.id AS id, - e.name AS name, - e.end_user_id AS end_user_id, - e.entity_type AS entity_type, - e.created_at AS created_at, - e.expired_at AS expired_at, - e.entity_idx AS entity_idx, - e.statement_id AS statement_id, - e.description AS description, - e.aliases AS aliases, - e.name_embedding AS name_embedding, - e.connect_strength AS connect_strength, - collect(DISTINCT s.id) AS statement_ids, - collect(DISTINCT c.id) AS chunk_ids, - COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, - COALESCE(e.importance_score, 0.5) AS importance_score, - e.last_access_time AS last_access_time, - COALESCE(e.access_count, 0) AS access_count, - 0.8 AS score ORDER BY score DESC LIMIT $limit """ From 14a32778f73c2535cdcba177b2d93c6d5d0dec76 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 24 Mar 2026 10:38:47 +0800 Subject: [PATCH 031/120] fix(memory-config): Resolve legacy config_id_old to UUID format - Update config ID validation to query config_id_old field instead of user_id - Raise InvalidConfigError when config_id_old mapping is not found instead of returning raw ID - Add _resolve_config_id_old method to map legacy integer config IDs to UUID format - Enhance agent memory config extraction to resolve legacy int/string formats to UUID - Improve workflow memory node config ID resolution with proper legacy format handling - Fix memory config serialization to always use UUID string format - Update log messages to clarify config_id_old field references and resolution status --- api/app/services/memory_config_service.py | 78 +++++++++++++++------- api/app/services/memory_storage_service.py | 6 +- 2 files changed, 56 insertions(+), 28 deletions(-) diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index 1a4af531..66c110b1 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -52,16 +52,20 @@ def _validate_config_id(config_id, db: Session = None): field_name="config_id", invalid_value=config_id, ) - # 如果提供了数据库会话,尝试通过 user_id 查询 config_id + # 如果提供了数据库会话,尝试通过 config_id_old 查询 config_id if db is not None: - # 查询 user_id 匹配的记录 - stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == str(config_id)) + # 查询 config_id_old 匹配的记录 + stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == config_id) result = db.execute(stmt).scalars().first() if result: - logger.info(f"Found config_id {result.config_id} for user_id {config_id}") + logger.info(f"Found config_id {result.config_id} for config_id_old {config_id}") return result.config_id - - return config_id + + raise InvalidConfigError( + f"未找到 config_id_old={config_id} 对应的配置", + field_name="config_id", + invalid_value=config_id, + ) if isinstance(config_id, str): config_id_stripped = config_id.strip() @@ -84,15 +88,19 @@ def _validate_config_id(config_id, db: Session = None): # 如果提供了数据库会话,尝试通过 user_id 查询 config_id if db is not None: - # 查询 user_id 匹配的记录 - stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id)) + # 查询 config_id_old 匹配的记录 + stmt = select(MemoryConfigModel).where(MemoryConfigModel.config_id_old == parsed_id) result = db.execute(stmt).scalars().first() if result: - logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}") + logger.info(f"Found config_id {result.config_id} for config_id_old {parsed_id}") return result.config_id - - return parsed_id + + raise InvalidConfigError( + f"未找到 config_id_old={parsed_id} 对应的配置", + field_name="config_id", + invalid_value=config_id, + ) except ValueError: raise InvalidConfigError( f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)", @@ -869,6 +877,23 @@ class MemoryConfigService: logger.warning(f"不支持的应用类型,无法提取记忆配置: app_type={app_type}") return None, False + def _resolve_config_id_old(self, config_id_old: int) -> Optional[uuid.UUID]: + """通过 config_id_old 查询对应的 UUID config_id。 + + Args: + config_id_old: 旧格式的整数配置ID + + Returns: + 对应的 UUID config_id,未找到返回 None + """ + from app.models.memory_config_model import MemoryConfig as MemoryConfigModel + result = self.db.query(MemoryConfigModel).filter( + MemoryConfigModel.config_id_old == config_id_old + ).first() + if result: + return result.config_id + return None + def _extract_memory_config_id_from_agent( self, config: dict @@ -900,10 +925,11 @@ class MemoryConfigService: elif isinstance(memory_value, str): # Check if it's a numeric string (legacy int format) if memory_value.isdigit(): - logger.warning( - f"Agent 配置中 memory_config_id 为旧格式 int 字符串,将使用工作空间默认配置: " - f"value={memory_value}" - ) + resolved = self._resolve_config_id_old(int(memory_value)) + if resolved: + logger.info(f"Resolved legacy config_id_old={memory_value} to config_id={resolved}") + return resolved, False + logger.warning(f"未找到 config_id_old={memory_value} 对应的配置,将使用工作空间默认配置") return None, True try: return uuid.UUID(memory_value), False @@ -911,11 +937,11 @@ class MemoryConfigService: logger.warning(f"Invalid UUID string: {memory_value}") return None, False elif isinstance(memory_value, int): - # 旧数据存储为 int,需要回退到工作空间默认配置 - logger.warning( - f"Agent 配置中 memory_config_id 为旧格式 int,将使用工作空间默认配置: " - f"value={memory_value}" - ) + resolved = self._resolve_config_id_old(memory_value) + if resolved: + logger.info(f"Resolved legacy config_id_old={memory_value} to config_id={resolved}") + return resolved, False + logger.warning(f"未找到 config_id_old={memory_value} 对应的配置,将使用工作空间默认配置") return None, True else: logger.warning( @@ -963,10 +989,16 @@ class MemoryConfigService: elif isinstance(config_id, str): return uuid.UUID(config_id), False elif isinstance(config_id, int): - # 旧数据存储为 int,需要回退到工作空间默认配置 + resolved = self._resolve_config_id_old(config_id) + if resolved: + logger.info( + f"Resolved workflow legacy config_id_old={config_id} to config_id={resolved}: " + f"node_id={node.get('id')}, node_type={node_type}" + ) + return resolved, False logger.warning( - f"工作流记忆节点 config_id 为旧格式 int,将使用工作空间默认配置: " - f"node_id={node.get('id')}, node_type={node_type}, value={config_id}" + f"未找到工作流记忆节点 config_id_old={config_id} 对应的配置,将使用工作空间默认配置: " + f"node_id={node.get('id')}, node_type={node_type}" ) return None, True else: diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 264ae4df..58f3e8bd 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -241,12 +241,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) except (ValueError, TypeError): config_id_old = None - if config_id_old: - memory_config = config_id_old - else: - memory_config = config.config_id config_dict = { - "config_id": memory_config, + "config_id": str(config.config_id), "config_name": config.config_name, "config_desc": config.config_desc, "workspace_id": str(config.workspace_id) if config.workspace_id else None, From 7a3220aff52424847cebe1d2d1ca23ed0f78235c Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 24 Mar 2026 12:11:35 +0800 Subject: [PATCH 032/120] chore: Move LICENSE file to project root - Relocate LICENSE from api/ directory to project root - Simplifies license visibility and accessibility for the entire project - Aligns with standard project structure conventions --- api/LICENSE => LICENSE | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename api/LICENSE => LICENSE (100%) diff --git a/api/LICENSE b/LICENSE similarity index 100% rename from api/LICENSE rename to LICENSE From 65b2f9e6e1bd352a8a64735478f66cbfeb1415a4 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 25 Mar 2026 18:57:35 +0800 Subject: [PATCH 033/120] [changes] AI reviews and modifies the code --- api/app/repositories/end_user_repository.py | 4 ++-- api/app/services/user_memory_service.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 8ac65cd7..d8d30618 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -115,8 +115,8 @@ class EndUserRepository: end_user_info = EndUserInfo( end_user_id=end_user.id, other_name=other_name or "", # 如果没有提供 other_name,使用空字符串 - aliases={}, # 空字典而不是 None - meta_data={} # 空字典而不是 None + aliases=[], + meta_data=[] ) self.db.add(end_user_info) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 28dabe10..f6239c76 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -510,7 +510,7 @@ class UserMemoryService: if aliases_updated: try: import asyncio - asyncio.create_task(self._sync_aliases_to_neo4j(end_user_id, update_data['aliases'])) + asyncio.run(self._sync_aliases_to_neo4j(end_user_id, update_data['aliases'])) logger.info(f"已触发 aliases 同步到 Neo4j: end_user_id={end_user_id}, aliases={update_data['aliases']}") except Exception as sync_error: logger.error(f"触发同步 aliases 到 Neo4j 失败: {sync_error}", exc_info=True) From f92eb9f45acf910fb930c931f7ccdd6cd2a30604 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 25 Mar 2026 19:23:31 +0800 Subject: [PATCH 034/120] [changes] Remove the unnecessary prompts --- .../prompt/prompts/extract_triplet.jinja2 | 84 ++----------------- 1 file changed, 5 insertions(+), 79 deletions(-) diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index 09e6ff8d..f9f2f45c 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -75,9 +75,8 @@ Extract entities and knowledge triplets from the given statement. **CRITICAL RULES (违反将导致提取失败):** -1. **EVERY entity MUST have BOTH fields:** +1. **EVERY entity MUST have aliases field:** - `"aliases": [...]` - REQUIRED, even if empty `[]` - - `"denied_aliases": [...]` - REQUIRED, even if empty `[]` 2. **ALIASES - 别名提取规则:** {% if language == "zh" %} @@ -100,12 +99,9 @@ Extract entities and knowledge triplets from the given statement. - Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names {% endif %} - -4. **USER ENTITY SPECIAL HANDLING:** + +3. **USER ENTITY SPECIAL HANDLING:** {% if language == "zh" %} - 用户实体的 name 字段:使用 "用户" 或 "我" - 用户的真实姓名:放入 aliases @@ -118,46 +114,24 @@ Extract entities and knowledge triplets from the given statement. * "I'm John" → name="User", aliases=["John"] {% endif %} - -5. **CONFLICT RESOLUTION:** + +4. **ALIASES ORDER:** {% if language == "zh" %} - 顺序优先级:按出现顺序,先出现的在前 {% else %} - Order priority: by appearance order, first mentioned comes first {% endif %} - - **EXAMPLES OF CORRECT EXTRACTION:** {% if language == "zh" %} - "我叫张三" → aliases=["张三"] (张三将成为 other_name) - "大家叫我小明,我全名叫李明" → aliases=["小明", "李明"] (小明先出现,将成为 other_name) - "我是李华,网名叫华仔" → aliases=["李华", "华仔"] (李华先出现,将成为 other_name) - - {% else %} - "I'm John" → aliases=["John"] (John will become other_name) - "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name) - "I'm John Smith, username JSmith" → aliases=["John Smith", "JSmith"] (John Smith appears first, will become other_name) - - {% endif %} - Exclude lengthy quotes, dates, temporal expressions @@ -278,50 +252,7 @@ Output: ] } -**Example 6 (否定别名 - Chinese):** "我不叫陈思远,我其实叫小小张" -Output: -{ - "triplets": [], - "entities": [ - {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["小小张"], "denied_aliases": ["陈思远"], "is_explicit_memory": false} - ] -} -**Example 7 (否定别名 - Chinese):** "我不叫远山" -Output: -{ - "triplets": [], - "entities": [ - {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "denied_aliases": ["远山"], "is_explicit_memory": false} - ] -} - -**Example 8 (复杂场景 - Chinese):** "大家都叫我明明,我的全名是小明,但我不是小红" -Output: -{ - "triplets": [], - "entities": [ - {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["明明", "小明"], "denied_aliases": ["小红"], "is_explicit_memory": false} - ] -} - -**Example 9 (纠正错误 - Chinese):** "我搞错了,我的网名不叫做远山,网名叫做大山" -Output: -{ - "triplets": [], - "entities": [ - {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["大山"], "denied_aliases": ["远山"], "is_explicit_memory": false} - ] -} - -**Example 10 (多重纠正 - Chinese):** "其实我不是老张,也不叫小张,我叫张三" -Output: -{ - "triplets": [], - "entities": [ - {"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["张三"], "denied_aliases": ["老张", "小张"], "is_explicit_memory": false} - ] -} {% endif %} ===End of Examples=== @@ -342,9 +273,4 @@ Output: - **⚠️ ALIASES ORDER: preserve temporal order of appearance** - **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []** - - {{ json_schema }} From 30b5db1e98eb6e85c0894b4807fd59371279f9dd Mon Sep 17 00:00:00 2001 From: Mark Date: Wed, 25 Mar 2026 21:15:40 +0800 Subject: [PATCH 035/120] [add] migration script --- .../versions/1ea8fe97b5b7_202603252115.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 api/migrations/versions/1ea8fe97b5b7_202603252115.py diff --git a/api/migrations/versions/1ea8fe97b5b7_202603252115.py b/api/migrations/versions/1ea8fe97b5b7_202603252115.py new file mode 100644 index 00000000..1f0df3e7 --- /dev/null +++ b/api/migrations/versions/1ea8fe97b5b7_202603252115.py @@ -0,0 +1,42 @@ +"""202603252115 + +Revision ID: 1ea8fe97b5b7 +Revises: e28bcc212da5 +Create Date: 2026-03-25 21:14:41.825048 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1ea8fe97b5b7' +down_revision: Union[str, None] = 'e28bcc212da5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tenants', sa.Column('contact_name', sa.String(length=100), nullable=True)) + op.add_column('tenants', sa.Column('contact_email', sa.String(length=255), nullable=True)) + op.add_column('tenants', sa.Column('contact_phone', sa.String(length=50), nullable=True)) + op.add_column('tenants', sa.Column('plan', sa.String(length=50), nullable=True)) + op.add_column('tenants', sa.Column('plan_expired_at', sa.DateTime(), nullable=True)) + op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.String(length=100), nullable=True)) + op.add_column('tenants', sa.Column('status', sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('tenants', 'status') + op.drop_column('tenants', 'api_ops_rate_limit') + op.drop_column('tenants', 'plan_expired_at') + op.drop_column('tenants', 'plan') + op.drop_column('tenants', 'contact_phone') + op.drop_column('tenants', 'contact_email') + op.drop_column('tenants', 'contact_name') + # ### end Alembic commands ### From b7a03a844f6653b9323185859b5550379c5d477f Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Thu, 26 Mar 2026 10:06:05 +0800 Subject: [PATCH 036/120] feat(agent): Opening remarks and document citation function --- api/app/services/app_chat_service.py | 22 +++++++++--- api/app/services/draft_run_service.py | 48 ++++++++++++++++++--------- 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 7f91dada..3dda6fc0 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -82,6 +82,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -93,7 +99,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) memory_flag = False if memory: memory_tools, memory_flag = self.agent_service.load_memory_config( @@ -230,7 +237,7 @@ class AppChatService: }), "elapsed_time": elapsed_time, "suggested_questions": suggested_questions, - "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), + "citations": self.agent_service._filter_citations(features_config, citations_collector), "audio_url": audio_url, "audio_status": "pending" } @@ -283,6 +290,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -295,7 +308,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -409,7 +423,7 @@ class AppChatService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self.agent_service._filter_citations(features_config, []) + end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector) # 保存消息 human_meta = { diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 09a202cd..d71d5c24 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -190,13 +190,14 @@ def create_web_search_tool(web_search_config: Dict[str, Any]): return web_search_tool -def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): +def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: list = None): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: kb_config: 知识库配置 kb_ids: 知识库ID列表 user_id: 用户ID + citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充) Returns: 检索到的相关知识内容 @@ -229,6 +230,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): } ) + # 收集引用信息 + if citations_collector is not None: + seen_doc_ids = {c.get("document_id") for c in citations_collector} + for chunk in retrieve_chunks_result: + meta = chunk.metadata or {} + doc_id = meta.get("document_id") or meta.get("doc_id") + if doc_id and doc_id not in seen_doc_ids: + seen_doc_ids.add(doc_id) + citations_collector.append({ + "document_id": doc_id, + "file_name": meta.get("file_name", ""), + "knowledge_id": str(meta.get("knowledge_id", "")), + "score": meta.get("score", 0), + }) + return f"检索到以下相关信息:\n\n{context}" else: logger.warning("知识库检索未找到结果") @@ -320,26 +336,26 @@ class AgentRunService: self, knowledge_retrieval_config: dict | None, user_id - ) -> list: + ) -> tuple[list, list]: + """返回 (tools, citations_collector)""" if not knowledge_retrieval_config: - return [] + return [], [] + citations_collector = [] tools = [] knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + kb_tool = create_knowledge_retrieval_tool( + knowledge_retrieval_config, kb_ids, user_id, + citations_collector=citations_collector + ) tools.append(kb_tool) - logger.debug( "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } + extra={"kb_ids": kb_ids, "tool_count": len(tools)} ) - return tools + return tools, citations_collector def load_memory_config( self, @@ -549,7 +565,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -680,7 +697,7 @@ class AgentRunService: "suggested_questions": await self._generate_suggested_questions( features_config, result["content"], api_key_config, effective_params ) if not sub_agent else [], - "citations": self._filter_citations(features_config, result.get("citations", [])), + "citations": self._filter_citations(features_config, citations_collector), "audio_url": audio_url, "audio_status": "pending" } @@ -790,7 +807,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -943,7 +961,7 @@ class AgentRunService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self._filter_citations(features_config, []) + end_data["citations"] = self._filter_citations(features_config, citations_collector) yield self._format_sse_event("end", end_data) logger.info( From 2525f8795c29190b7a882a9ce8f7461350d42a6a Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Thu, 26 Mar 2026 10:47:13 +0800 Subject: [PATCH 037/120] feat(agent): Opening remarks and document citation function --- api/app/schemas/app_schema.py | 7 +++++++ api/app/services/draft_run_service.py | 27 ++++++++++++++++----------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 1582d862..e34945eb 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -196,6 +196,13 @@ class CitationConfig(BaseModel): enabled: bool = Field(default=False) +class Citation(BaseModel): + document_id: str + file_name: str + knowledge_id: str + score: float + + class WebSearchConfig(BaseModel): """联网搜索配置""" enabled: bool = Field(default=False) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index d71d5c24..ac34b4de 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context from app.models import AgentConfig, ModelConfig, ModelType from app.repositories.tool_repository import ToolRepository -from app.schemas.app_schema import FileInput +from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service @@ -190,7 +190,7 @@ def create_web_search_tool(web_search_config: Dict[str, Any]): return web_search_tool -def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: list = None): +def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: @@ -198,6 +198,11 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collec kb_ids: 知识库ID列表 user_id: 用户ID citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充) + 列表元素类型为 Citation,包含字段: + - document_id: 文档唯一标识 + - file_name: 文件名 + - knowledge_id: 知识库 ID + - score: 检索相关性得分 Returns: 检索到的相关知识内容 @@ -238,12 +243,12 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collec doc_id = meta.get("document_id") or meta.get("doc_id") if doc_id and doc_id not in seen_doc_ids: seen_doc_ids.add(doc_id) - citations_collector.append({ - "document_id": doc_id, - "file_name": meta.get("file_name", ""), - "knowledge_id": str(meta.get("knowledge_id", "")), - "score": meta.get("score", 0), - }) + citations_collector.append(Citation( + document_id=doc_id, + file_name=meta.get("file_name", ""), + knowledge_id=str(meta.get("knowledge_id", "")), + score=meta.get("score", 0) + )) return f"检索到以下相关信息:\n\n{context}" else: @@ -344,7 +349,7 @@ class AgentRunService: citations_collector = [] tools = [] knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: kb_tool = create_knowledge_retrieval_tool( knowledge_retrieval_config, kb_ids, user_id, @@ -457,12 +462,12 @@ class AgentRunService: @staticmethod def _filter_citations( features_config: Dict[str, Any], - citations: List[Any] + citations: List[Citation] ) -> List[Any]: """根据 citation 开关决定是否返回引用来源""" citation_cfg = features_config.get("citation", {}) if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"): - return citations + return [cit.model_dump() for cit in citations] return [] async def run( From f2c9902a07859910eeefe1db34ba96589a92b05e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 25 Mar 2026 19:19:44 +0800 Subject: [PATCH 038/120] [fix] Fix the forgotten periodic tasks --- api/app/services/memory_forget_service.py | 6 ++ api/app/tasks.py | 115 ++++++++++++---------- 2 files changed, 69 insertions(+), 52 deletions(-) diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index a0bcc1a1..11118571 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -315,6 +315,12 @@ class MemoryForgetService: # 获取遗忘引擎组件 _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id) + # 如果参数为 None,使用配置中的默认值 + if max_merge_batch_size is None: + max_merge_batch_size = config.get('max_merge_batch_size', 100) + if min_days_since_access is None: + min_days_since_access = config.get('min_days_since_access', 30) + # 记录执行开始时间 execution_time = datetime.now() diff --git a/api/app/tasks.py b/api/app/tasks.py index 3b81ced3..61736275 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -36,9 +36,11 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ) from app.db import get_db, get_db_context from app.models import Document, File, Knowledge +from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema from app.schemas.model_schema import ModelInfo -from app.services.memory_agent_service import MemoryAgentService +from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config +from app.services.memory_forget_service import MemoryForgetService from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id from app.utils.redis_lock import RedisLock @@ -1860,7 +1862,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: @celery_app.task( name="app.tasks.run_forgetting_cycle_task", bind=True, - ignore_result=True, + ignore_result=False, # 改为 False 以便在 Flower 中查看结果 max_retries=0, acks_late=False, time_limit=7200, @@ -1868,68 +1870,77 @@ def workspace_reflection_task(self) -> Dict[str, Any]: ) def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: """定时任务:运行遗忘周期 - - 定期执行遗忘周期,识别并融合低激活值的知识节点。 - - Args: - config_id: 配置ID(可选,如果为None则使用默认配置) - - Returns: - 包含任务执行结果的字典 + + 遍历所有终端用户,执行遗忘周期。 """ start_time = time.time() - async def _run() -> Dict[str, Any]: - from app.services.memory_forget_service import MemoryForgetService - + async def _process_users() -> Dict[str, Any]: with get_db_context() as db: - try: - logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}") + end_users = db.query(EndUser).all() + if not end_users: + logger.info("没有终端用户,跳过遗忘周期") + return {"status": "SUCCESS", "message": "没有终端用户", + "report": {"merged_count": 0, "failed_count": 0, "processed_users": 0}, + "duration_seconds": time.time() - start_time} - forget_service = MemoryForgetService() + logger.info(f"开始处理 {len(end_users)} 个终端用户的遗忘周期") + forget_service = MemoryForgetService() + total_merged = total_failed = processed_users = 0 + failed_users = [] - # 运行遗忘周期 - # FIXME: MemeoryForgetService - report = await forget_service.trigger_forgetting( - db=db, - end_user_id=None, # 处理所有组 - config_id=config_id - ) + for end_user in end_users: + try: + # 获取用户配置(自动回退到工作空间默认配置) + connected_config = get_end_user_connected_config(str(end_user.id), db) + user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db) + + if not user_config_id: + failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"}) + continue - duration = time.time() - start_time + # 执行遗忘周期 + report = await forget_service.trigger_forgetting_cycle( + db=db, end_user_id=str(end_user.id), config_id=user_config_id + ) + + total_merged += report.get('merged_count', 0) + total_failed += report.get('failed_count', 0) + processed_users += 1 + + logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点") + + except Exception as e: + logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True) + failed_users.append({"end_user_id": str(end_user.id), "error": str(e)}) - logger.info( - f"遗忘周期定时任务完成: " - f"融合 {report['merged_count']} 对节点, " - f"失败 {report['failed_count']} 对, " - f"耗时 {duration:.2f} 秒" - ) + duration = time.time() - start_time + logger.info(f"遗忘周期完成: {processed_users}/{len(end_users)} 用户, " + f"融合 {total_merged} 对, 耗时 {duration:.2f}s") - return { - "status": "SUCCESS", - "message": "遗忘周期执行成功", - "report": report, - "duration_seconds": duration - } - - except Exception as e: - duration = time.time() - start_time - logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) - - return { - "status": "FAILED", - "message": f"遗忘周期执行失败: {str(e)}", - "duration_seconds": duration - } + return { + "status": "SUCCESS", + "message": f"处理 {processed_users} 个用户", + "report": { + "merged_count": total_merged, + "failed_count": total_failed, + "processed_users": processed_users, + "total_users": len(end_users), + "failed_users": failed_users + }, + "duration_seconds": duration + } # 运行异步函数 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) try: - result = loop.run_until_complete(_run()) - return result - finally: - loop.close() + return asyncio.run(_process_users()) + except Exception as e: + logger.error(f"遗忘周期任务失败: {e}", exc_info=True) + return { + "status": "FAILED", + "message": f"任务失败: {str(e)}", + "duration_seconds": time.time() - start_time + } # ============================================================================= From 9d2f3aa8f900cd1a168fac647aa0826d71964f97 Mon Sep 17 00:00:00 2001 From: wxy Date: Thu, 26 Mar 2026 11:50:36 +0800 Subject: [PATCH 039/120] feat: version introduction support db source with json fallback --- api/app/repositories/home_page_repository.py | 63 +++++++++++++++++++- api/app/services/home_page_service.py | 27 ++++++--- 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/api/app/repositories/home_page_repository.py b/api/app/repositories/home_page_repository.py index bcb3b622..6d74bcaf 100644 --- a/api/app/repositories/home_page_repository.py +++ b/api/app/repositories/home_page_repository.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta from sqlalchemy.orm import Session from sqlalchemy import func from uuid import UUID -from typing import Dict +from typing import Dict, Optional, Any from app.models.end_user_model import EndUser from app.models.user_model import User @@ -190,4 +190,63 @@ class HomePageRepository: user_count_dict = {workspace_id: count for workspace_id, count in user_counts} - return workspaces, app_count_dict, user_count_dict \ No newline at end of file + return workspaces, app_count_dict, user_count_dict + + @staticmethod + def get_version_introduction(db: Session, version: str) -> Optional[Dict[str, Any]]: + """ + 从数据库获取版本说明(优先读取已发布的版本) + 使用反射方式读取表结构,不依赖 premium 模型类 + + Args: + db: 数据库会话 + version: 版本号,如 "v0.2.7" + + Returns: + 版本说明字典,格式与 version_info.json 一致 + 如果数据库中没有该版本,返回 None + """ + try: + from sqlalchemy import Table, MetaData + + metadata = MetaData() + version_notes = Table('version_notes', metadata, autoload_with=db.engine) + version_note_items = Table('version_note_items', metadata, autoload_with=db.engine) + + note = db.query(version_notes).filter( + version_notes.c.version == version, + version_notes.c.is_published == True + ).first() + + if not note: + return None + + items = db.query(version_note_items).filter( + version_note_items.c.note_id == note.id + ).order_by(version_note_items.c.sort_order).all() + + core_upgrades = [] + for item in items: + title = item.title + content = item.content + if content: + core_upgrades.append(f"{title}
{content}") + else: + core_upgrades.append(title) + + return { + "introduction": { + "codeName": "", + "releaseDate": note.release_date.isoformat() if note.release_date else "", + "upgradePosition": "", + "coreUpgrades": core_upgrades + }, + "introduction_en": { + "codeName": "", + "releaseDate": note.release_date.isoformat() if note.release_date else "", + "upgradePosition": "", + "coreUpgrades": core_upgrades + } + } + except Exception: + return None \ No newline at end of file diff --git a/api/app/services/home_page_service.py b/api/app/services/home_page_service.py index 8326ad40..4e6bf664 100644 --- a/api/app/services/home_page_service.py +++ b/api/app/services/home_page_service.py @@ -94,29 +94,38 @@ class HomePageService: @staticmethod def load_version_introduction(version: str) -> Dict[str, Any]: """ - 从 JSON 文件加载对应版本的介绍 + 加载对应版本的介绍(优先从数据库读取,fallback 到 JSON 文件) :param version: 系统版本号(如 "0.2.0") :return: 对应版本的详细介绍 """ - # 2. 定义 JSON 文件路径(简化路径处理,保留绝对路径调试特性) + from copy import deepcopy + from app.db import SessionLocal + from app.repositories.home_page_repository import HomePageRepository + + result = deepcopy(HomePageService.DEFAULT_RETURN_DATA) + + try: + db = SessionLocal() + try: + db_result = HomePageRepository.get_version_introduction(db, version) + if db_result: + return db_result + finally: + db.close() + except Exception as e: + pass + json_abs_path = Path(__file__).parent.parent / "version_info.json" json_abs_path = json_abs_path.resolve() - # 3. 初始化返回结果(深拷贝默认模板,避免修改原常量) - from copy import deepcopy - result = deepcopy(HomePageService.DEFAULT_RETURN_DATA) - try: - # 4. 简化文件存在性判断(合并逻辑,减少分支) if not json_abs_path.exists(): result["message"] = f"版本介绍文件不存在:{json_abs_path}" return result - # 5. 读取并解析 JSON 文件(简化文件操作流程) with open(json_abs_path, "r", encoding="utf-8") as f: changelogs = json.load(f) - # 6. 简化版本匹配逻辑,直接返回结果或更新提示信息 if version in changelogs: return changelogs[version] result["message"] = f"暂未查询到 {version} 版本的详细介绍" From 4d4a780ab7b60254ba0422b74cd30a10d8e86718 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 26 Mar 2026 12:05:53 +0800 Subject: [PATCH 040/120] style(memory): Pref an anomaly in the message null check logic. --- .../core/memory/agent/langgraph_graph/routing/write_router.py | 2 +- api/app/services/memory_agent_service.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 6176caf5..2074b6ca 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): logger.info('写入长期记忆NEO4J') - formatted_messages = (redis_messages) + formatted_messages = redis_messages # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index e5c34492..289fd74c 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -357,6 +357,7 @@ class MemoryAgentService: if file_object is None: continue message["file_content"].append((file_object, file["type"])) + logger.info(messages) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: @@ -606,7 +607,7 @@ class MemoryAgentService: retrieved_content.append({query: statements}) # 如果 retrieved_content 为空,设置为空字符串 - if retrieved_content == []: + if not retrieved_content: retrieved_content = '' # 只有当回答不是"信息不足"且不是快速检索时才保存 From 863be50aafef5fd4d29b3ef040913fee9c9e7ae6 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 26 Mar 2026 15:03:33 +0800 Subject: [PATCH 041/120] [changes] Spatial verification, retrieval synchronization --- .../controllers/user_memory_controllers.py | 23 +++++++++++++++++++ .../extraction_orchestrator.py | 5 ++-- .../repositories/end_user_info_repository.py | 7 +++--- api/app/repositories/end_user_repository.py | 4 ++-- api/app/services/user_memory_service.py | 14 ++++------- 5 files changed, 35 insertions(+), 18 deletions(-) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index b0dc82a0..10b396a7 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -23,6 +23,7 @@ from app.services.memory_entity_relationship_service import MemoryEntityService, from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository +from app.repositories.end_user_repository import EndUserRepository from app.schemas.end_user_info_schema import ( EndUserInfoResponse, EndUserInfoCreate, @@ -361,6 +362,17 @@ async def get_end_user_info( f"workspace={workspace_id}" ) + # 校验 end_user 是否属于当前工作空间 + end_user_repo = EndUserRepository(db) + end_user = end_user_repo.get_end_user_by_id(end_user_id) + if end_user is None: + return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found") + if str(end_user.workspace_id) != str(workspace_id): + api_logger.warning( + f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}" + ) + return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch") + result = user_memory_service.get_end_user_info(db, end_user_id) if result["success"]: @@ -409,6 +421,17 @@ async def update_end_user_info( f"workspace={workspace_id}" ) + # 校验 end_user 是否属于当前工作空间 + end_user_repo = EndUserRepository(db) + end_user = end_user_repo.get_end_user_by_id(end_user_id) + if end_user is None: + return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found") + if str(end_user.workspace_id) != str(workspace_id): + api_logger.warning( + f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}" + ) + return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch") + # 获取更新数据(排除 end_user_id) update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index d5681da9..58a4c441 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1389,9 +1389,8 @@ class ExtractionOrchestrator: logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}") # 更新或创建 end_user_info 记录 - existing_infos = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) - if existing_infos: - info = existing_infos[0] + info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + if info: new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases) if new_name_info is not None: info.other_name = new_name_info diff --git a/api/app/repositories/end_user_info_repository.py b/api/app/repositories/end_user_info_repository.py index f9f4665c..f627b46f 100644 --- a/api/app/repositories/end_user_info_repository.py +++ b/api/app/repositories/end_user_info_repository.py @@ -35,9 +35,10 @@ class EndUserInfoRepository: """根据ID获取用户信息""" return self.db.query(EndUserInfo).filter(EndUserInfo.id == info_id).first() - def get_by_end_user_id(self, end_user_id: uuid.UUID) -> List[EndUserInfo]: - """获取用户的所有信息记录""" - return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).all() + + def get_by_end_user_id(self, end_user_id: uuid.UUID) -> Optional[EndUserInfo]: + """获取用户的信息记录""" + return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).first() def update(self, info_id: uuid.UUID, aliases: List[str] = None, meta_data: dict = None) -> Optional[EndUserInfo]: """更新用户信息""" diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index d8d30618..3c1dd16f 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -115,8 +115,8 @@ class EndUserRepository: end_user_info = EndUserInfo( end_user_id=end_user.id, other_name=other_name or "", # 如果没有提供 other_name,使用空字符串 - aliases=[], - meta_data=[] + aliases=[], + meta_data={} ) self.db.add(end_user_info) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index f6239c76..942e01a0 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -387,9 +387,9 @@ class UserMemoryService: # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) - end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) + end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) - if not end_user_info_records: + if not end_user_info_record: logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, @@ -397,9 +397,6 @@ class UserMemoryService: "error": "终端用户信息记录不存在" } - # 获取第一条记录 - end_user_info_record = end_user_info_records[0] - # 构建响应数据(转换时间为毫秒时间戳) response_data = { "end_user_info_id": str(end_user_info_record.id), @@ -462,9 +459,9 @@ class UserMemoryService: # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) - end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) + end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) - if not end_user_info_records: + if not end_user_info_record: logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, @@ -472,9 +469,6 @@ class UserMemoryService: "error": "终端用户信息记录不存在" } - # 获取第一条记录 - end_user_info_record = end_user_info_records[0] - # 定义允许更新的字段白名单 allowed_fields = {'other_name', 'aliases', 'meta_data'} From 477853b04ebaf5eed71410a0eebed979d914e21a Mon Sep 17 00:00:00 2001 From: wxy Date: Thu, 26 Mar 2026 15:45:16 +0800 Subject: [PATCH 042/120] feat: Add feature_billing and feature_user_management fields to tenant model --- api/app/models/tenant_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index 044857d2..75a480b1 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -24,6 +24,10 @@ class Tenants(Base): default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言 supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表 + # 租户功能开关字段 + feature_billing = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用收费管理菜单") + feature_user_management = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用用户管理菜单") + # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") From 68489f1b289e4b27532b06847acd1566084f4ab9 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Thu, 26 Mar 2026 16:05:24 +0800 Subject: [PATCH 043/120] feat(workflow): Document extraction node --- .../nodes/document_extractor/__init__.py | 4 + .../nodes/document_extractor/config.py | 23 ++++ .../workflow/nodes/document_extractor/node.py | 101 ++++++++++++++++++ api/app/core/workflow/nodes/enums.py | 1 + api/app/core/workflow/nodes/node_factory.py | 5 +- api/app/services/multimodal_service.py | 65 ----------- 6 files changed, 133 insertions(+), 66 deletions(-) create mode 100644 api/app/core/workflow/nodes/document_extractor/__init__.py create mode 100644 api/app/core/workflow/nodes/document_extractor/config.py create mode 100644 api/app/core/workflow/nodes/document_extractor/node.py diff --git a/api/app/core/workflow/nodes/document_extractor/__init__.py b/api/app/core/workflow/nodes/document_extractor/__init__.py new file mode 100644 index 00000000..c51bc2c0 --- /dev/null +++ b/api/app/core/workflow/nodes/document_extractor/__init__.py @@ -0,0 +1,4 @@ +from .config import DocExtractorNodeConfig +from .node import DocExtractorNode + +__all__ = ["DocExtractorNode", "DocExtractorNodeConfig"] diff --git a/api/app/core/workflow/nodes/document_extractor/config.py b/api/app/core/workflow/nodes/document_extractor/config.py new file mode 100644 index 00000000..dd946422 --- /dev/null +++ b/api/app/core/workflow/nodes/document_extractor/config.py @@ -0,0 +1,23 @@ +from pydantic import Field +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class DocExtractorNodeConfig(BaseNodeConfig): + file_selector: str = Field( + ..., + description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}" + ) + output_format: str = Field( + default="text", + description="Output format: 'text' | 'markdown'" + ) + + class Config: + json_schema_extra = { + "examples": [ + { + "file_selector": "{{ sys.files }}", + "output_format": "text" + } + ] + } diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py new file mode 100644 index 00000000..050f693f --- /dev/null +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -0,0 +1,101 @@ +import logging +from typing import Any + +from app.core.workflow.engine.state_manager import WorkflowState +from app.core.workflow.engine.variable_pool import VariablePool +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig +from app.core.workflow.variable.base_variable import VariableType, FileObject +from app.db import get_db_read +from app.schemas.app_schema import FileInput, FileType, TransferMethod + +logger = logging.getLogger(__name__) + + +def _file_object_to_file_input(f: FileObject) -> FileInput: + """Convert workflow FileObject to multimodal FileInput.""" + return FileInput( + type=FileType.DOCUMENT, + transfer_method=TransferMethod(f.transfer_method), + url=f.url or None, + upload_file_id=f.file_id or None, + file_type=f.origin_file_type or "", + ) + + +def _normalise_files(val: Any) -> list[FileObject]: + if isinstance(val, FileObject): + return [val] + if isinstance(val, dict) and val.get("is_file"): + return [FileObject(**val)] + if isinstance(val, list): + result = [] + for item in val: + if isinstance(item, FileObject): + result.append(item) + elif isinstance(item, dict) and item.get("is_file"): + result.append(FileObject(**item)) + return result + return [] + + +class DocExtractorNode(BaseNode): + """Document Extractor Node. + + Reads one or more file variables and extracts their text content + by delegating to MultimodalService._extract_document_text. + + Outputs: + text (string) – full concatenated text of all input files + chunks (array[string]) – per-file extracted text + """ + + def _output_types(self) -> dict[str, VariableType]: + return { + "text": VariableType.STRING, + "chunks": VariableType.ARRAY_STRING, + } + + def _extract_output(self, business_result: Any) -> Any: + return business_result + + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + return {"file_selector": self.config.get("file_selector")} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: + config = DocExtractorNodeConfig(**self.config) + + raw_val = self.get_variable(config.file_selector, variable_pool, strict=False) + if raw_val is None: + logger.warning(f"Node {self.node_id}: file variable '{config.file_selector}' is empty") + return {"text": "", "chunks": []} + + files = _normalise_files(raw_val) + if not files: + return {"text": "", "chunks": []} + + chunks: list[str] = [] + with get_db_read() as db: + from app.services.multimodal_service import MultimodalService + svc = MultimodalService(db) + for f in files: + try: + file_input = _file_object_to_file_input(f) + # Ensure URL is populated for local files + if not file_input.url: + file_input.url = await svc.get_file_url(file_input) + # Reuse cached bytes if already fetched + if f.get_content(): + file_input.set_content(f.get_content()) + text = await svc._extract_document_text(file_input) + chunks.append(text) + except Exception as e: + logger.error( + f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}", + exc_info=True, + ) + chunks.append("") + + full_text = "\n\n".join(c for c in chunks if c) + logger.info(f"Node {self.node_id}: extracted {len(files)} file(s), total chars={len(full_text)}") + return {"text": full_text, "chunks": chunks} diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 5a603ac9..529cd0b3 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -23,6 +23,7 @@ class NodeType(StrEnum): BREAK = "break" MEMORY_READ = "memory-read" MEMORY_WRITE = "memory-write" + DOCUMENT_EXTRACTOR = "document-extractor" UNKNOWN = "unknown" NOTES = "notes" diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 9e5a7d24..49add867 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -26,6 +26,7 @@ from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.breaker import BreakNode from app.core.workflow.nodes.tool import ToolNode +from app.core.workflow.nodes.document_extractor import DocExtractorNode logger = logging.getLogger(__name__) @@ -49,7 +50,8 @@ WorkflowNode = Union[ ToolNode, MemoryReadNode, MemoryWriteNode, - CodeNode + CodeNode, + DocExtractorNode ] @@ -81,6 +83,7 @@ class NodeFactory: NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_WRITE: MemoryWriteNode, NodeType.CODE: CodeNode, + NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode } @classmethod diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 3afd6206..4cf3d89d 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -403,71 +403,6 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - async def history_process_files( - self, - files: Optional[List[FileInput]], - ) -> List[Dict[str, Any]]: - """ - 处理文件列表,返回 LLM 可用的格式 - - Args: - files: 文件输入列表 - - Returns: - List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式) - """ - if not files: - return [] - - # 获取对应的策略 - # dashscope 的 omni 模型使用 OpenAI 兼容格式 - if self.provider == "dashscope" and self.is_omni: - strategy_class = OpenAIFormatStrategy - else: - strategy_class = PROVIDER_STRATEGIES.get(self.provider) - if not strategy_class: - logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略") - strategy_class = DashScopeFormatStrategy - - result = [] - for idx, file in enumerate(files): - strategy = strategy_class(file) - if not file.url: - file.url = await self.get_file_url(file) - try: - if file.type == FileType.IMAGE and "vision" in self.capability: - is_support, content = await self._process_image(file, strategy) - result.append(content) - elif file.type == FileType.DOCUMENT: - is_support, content = await self._process_document(file, strategy) - result.append(content) - elif file.type == FileType.AUDIO and "audio" in self.capability: - is_support, content = await self._process_audio(file, strategy) - result.append(content) - elif file.type == FileType.VIDEO and "video" in self.capability: - is_support, content = await self._process_video(file, strategy) - result.append(content) - else: - logger.warning(f"不支持的文件类型: {file.type}") - except Exception as e: - logger.error( - f"处理文件失败", - extra={ - "file_index": idx, - "file_type": file.type, - "error": str(e) - }, - exc_info=True - ) - # 继续处理其他文件,不中断整个流程 - result.append({ - "type": "text", - "text": f"[文件处理失败: {str(e)}]" - }) - - logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") - return result - async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理图片文件 From 6223b80cc4aff131f4262c33ce7786010d5574b2 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 26 Mar 2026 16:19:01 +0800 Subject: [PATCH 044/120] fix(workflow): Fix LLM node, resolve abnormal field reading issue in message caching functionality --- api/app/core/workflow/nodes/base_node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 0b31c9e3..8567ebbe 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -639,8 +639,8 @@ class BaseNode(ABC): return content elif isinstance(content, FileObject): - if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): - return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] + if content.content_cache.get(f"{provider}_{api_config.is_omni}"): + return content.content_cache[f"{provider}_{api_config.is_omni}"] with get_db_read() as db: multimodal_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( @@ -656,7 +656,7 @@ class BaseNode(ABC): ) content.set_content(file_obj.get_content()) if message: - content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message + content.content_cache[f"{provider}_{api_config.is_omni}"] = message return message return None raise TypeError(f'Unexpected input value type - {type(content)}') From 1df3fc416ade2bd39b165f189b8c1e95e727e712 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Thu, 26 Mar 2026 16:19:40 +0800 Subject: [PATCH 045/120] feat(workflow): Document extraction node --- api/app/core/workflow/nodes/document_extractor/config.py | 7 +------ api/app/core/workflow/nodes/document_extractor/node.py | 4 +++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/api/app/core/workflow/nodes/document_extractor/config.py b/api/app/core/workflow/nodes/document_extractor/config.py index dd946422..69f7f76d 100644 --- a/api/app/core/workflow/nodes/document_extractor/config.py +++ b/api/app/core/workflow/nodes/document_extractor/config.py @@ -7,17 +7,12 @@ class DocExtractorNodeConfig(BaseNodeConfig): ..., description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}" ) - output_format: str = Field( - default="text", - description="Output format: 'text' | 'markdown'" - ) class Config: json_schema_extra = { "examples": [ { - "file_selector": "{{ sys.files }}", - "output_format": "text" + "file_selector": "{{ sys.files }}" } ] } diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py index 050f693f..40641f3c 100644 --- a/api/app/core/workflow/nodes/document_extractor/node.py +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -29,12 +29,14 @@ def _normalise_files(val: Any) -> list[FileObject]: if isinstance(val, dict) and val.get("is_file"): return [FileObject(**val)] if isinstance(val, list): - result = [] + result: list[FileObject] = [] for item in val: if isinstance(item, FileObject): result.append(item) elif isinstance(item, dict) and item.get("is_file"): result.append(FileObject(**item)) + else: + logger.warning("Ignoring non-file entry in file list for document extractor: %r", item) return result return [] From 2f4f7219e3c3289608b002bad0b05d24583a923f Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 26 Mar 2026 16:29:47 +0800 Subject: [PATCH 046/120] [add] migration script --- .../versions/adaefcbe2aa1_202603261630.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 api/migrations/versions/adaefcbe2aa1_202603261630.py diff --git a/api/migrations/versions/adaefcbe2aa1_202603261630.py b/api/migrations/versions/adaefcbe2aa1_202603261630.py new file mode 100644 index 00000000..b8235dd7 --- /dev/null +++ b/api/migrations/versions/adaefcbe2aa1_202603261630.py @@ -0,0 +1,32 @@ +"""202603261630 + +Revision ID: adaefcbe2aa1 +Revises: 1ea8fe97b5b7 +Create Date: 2026-03-26 16:27:17.590077 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'adaefcbe2aa1' +down_revision: Union[str, None] = '1ea8fe97b5b7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tenants', sa.Column('feature_billing', sa.Boolean(), server_default='false', nullable=False, comment='是否启用收费管理菜单')) + op.add_column('tenants', sa.Column('feature_user_management', sa.Boolean(), server_default='false', nullable=False, comment='是否启用用户管理菜单')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('tenants', 'feature_user_management') + op.drop_column('tenants', 'feature_billing') + # ### end Alembic commands ### From 8c6f395818320b95ed2c00ba2b44f85292392fb9 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Thu, 26 Mar 2026 16:36:53 +0800 Subject: [PATCH 047/120] refactor(app-service): Rename memory config extraction method for clarity - Rename `_extract_memory_config_id` to `_get_memory_config_id_from_release` to better reflect its purpose of retrieving memory config from release objects - Update method call in release creation flow - Update method call in release retrieval flow - Improves code readability by making the method's scope and responsibility more explicit --- api/app/services/app_service.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 19aaac42..4dcabff8 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1638,7 +1638,7 @@ class AppService: # ==================== 记忆配置提取方法 ==================== - def _extract_memory_config_id( + def _get_memory_config_id_from_release( self, app_type: str, config: Dict[str, Any] @@ -1863,7 +1863,7 @@ class AppService: self.db.flush() # 先 flush,确保 release 已插入数据库 # 提取记忆配置ID并更新终端用户 - memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config) + memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(app.type, config) # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: @@ -2001,7 +2001,7 @@ class AppService: raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}") # 提取记忆配置ID并更新终端用户 - memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config) + memory_config_id, is_legacy_int = self._get_memory_config_id_from_release(release.type, release.config) # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: From 2319432182a41f2b7bbfc350581bce88fc8720f6 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 26 Mar 2026 17:19:37 +0800 Subject: [PATCH 048/120] [changes] Set up Celery tasks to perform clustering --- api/app/celery_app.py | 3 + .../core/memory/agent/utils/write_tools.py | 50 +++-- .../clustering_engine/label_propagation.py | 199 ++++++++++++------ api/app/tasks.py | 94 +++++++++ 4 files changed, 263 insertions(+), 83 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 58c89f8f..23fd82ed 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -108,6 +108,9 @@ celery_app.conf.update( 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, + # Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题) + 'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 5829a5c9..55bcb8ba 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -8,6 +8,7 @@ import asyncio import time import uuid from datetime import datetime +from typing import List, Optional from dotenv import load_dotenv @@ -21,7 +22,7 @@ from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -177,28 +178,33 @@ async def write( if success: logger.info("Successfully saved all data to Neo4j") - # 同步用户别名到 PostgreSQL - try: - # 创建一个临时的 orchestrator 实例来调用同步方法 - temp_orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=neo4j_connector, - embedding_id=embedding_model_id - ) - await temp_orchestrator._update_end_user_other_name(all_entity_nodes, chunked_dialogs) - logger.info("Successfully synced user aliases to PostgreSQL") - except Exception as sync_error: - logger.error(f"Failed to sync user aliases to PostgreSQL: {sync_error}", exc_info=True) - # 不影响主流程 + # 使用 Celery 异步任务触发聚类(不阻塞主流程) + if all_entity_nodes: + try: + from app.tasks import run_incremental_clustering + + end_user_id = all_entity_nodes[0].end_user_id + new_entity_ids = [e.id for e in all_entity_nodes] + + # 异步提交 Celery 任务 + task = run_incremental_clustering.apply_async( + kwargs={ + "end_user_id": end_user_id, + "new_entity_ids": new_entity_ids, + "llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None, + "embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + }, + # 设置任务优先级(低优先级,不影响主业务) + priority=3, + ) + logger.info( + f"[Clustering] 增量聚类任务已提交到 Celery - " + f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}" + ) + except Exception as e: + # 聚类任务提交失败不影响主流程 + logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True) - # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突) - await _trigger_clustering_sync( - all_entity_nodes, - llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str( - memory_config.embedding_model_id) if memory_config.embedding_model_id else None, - ) break else: logger.warning("Failed to save some data to Neo4j") diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 0fa6a833..d0b121d7 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -76,6 +76,9 @@ class LabelPropagationEngine: self.repo = CommunityRepository(connector) self.llm_model_id = llm_model_id self.embedding_model_id = embedding_model_id + # 缓存客户端实例,避免重复初始化 + self._llm_client = None + self._embedder_client = None # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -215,8 +218,17 @@ class LabelPropagationEngine: 3. 若邻居无社区 → 创建新社区 4. 若邻居分属多个社区 → 评估是否合并 """ + # 收集所有需要生成元数据的社区ID + communities_to_update = set() + for entity_id in new_entity_ids: - await self._process_single_entity(entity_id, end_user_id) + cid = await self._process_single_entity(entity_id, end_user_id) + if cid: + communities_to_update.add(cid) + + # 批量生成所有社区的元数据 + if communities_to_update: + await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True) # ────────────────────────────────────────────────────────────────────────── # 内部方法 @@ -224,8 +236,13 @@ class LabelPropagationEngine: async def _process_single_entity( self, entity_id: str, end_user_id: str - ) -> None: - """处理单个新实体的社区分配。""" + ) -> Optional[str]: + """ + 处理单个新实体的社区分配。 + + Returns: + str: 分配到的社区ID(如果有) + """ neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id) # 查询自身 embedding(从邻居查询结果中无法获取,需单独查) @@ -237,8 +254,7 @@ class LabelPropagationEngine: await self.repo.upsert_community(new_cid, end_user_id, member_count=1) await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") - await self._generate_community_metadata([new_cid], end_user_id) - return + return new_cid # 统计邻居社区分布 community_ids_in_neighbors = set( @@ -260,7 +276,7 @@ class LabelPropagationEngine: logger.debug( f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" ) - await self._generate_community_metadata([new_cid], end_user_id) + return new_cid else: # 加入得票最多的社区 await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) @@ -272,8 +288,8 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) - # 新实体加入后成员变化,强制重新生成元数据 - await self._generate_community_metadata([target_cid], end_user_id, force=True) + # 返回目标社区ID,稍后批量生成元数据 + return target_cid async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -456,20 +472,19 @@ class LabelPropagationEngine: self, community_ids: List[str], end_user_id: str, force: bool = False ) -> None: """ - 为一个或多个社区生成并写入元数据。 + 为一个或多个社区生成并写入元数据(优化版:批量 LLM 调用)。 流程: - 1. 逐个社区调 LLM 生成 name / summary(串行) - 2. 收集所有 summary,一次性批量 embed - 3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata + 1. 批量准备所有社区的 prompt + 2. 并发调用 LLM 生成所有社区的 name / summary + 3. 批量 embed 所有 summary + 4. 批量写入数据库 Args: force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后) """ - from app.db import get_db_context - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - - async def _build_one(cid: str) -> Optional[Dict]: + async def _prepare_one(cid: str) -> Optional[Dict]: + """准备单个社区的数据和 prompt""" try: if not force: check_embedding = bool(self.embedding_model_id) @@ -489,42 +504,32 @@ class LabelPropagationEngine: core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] all_names = [m["name"] for m in members if m.get("name")] + # 默认值 name = "、".join(core_entities[:3]) if core_entities else cid[:8] summary = f"包含实体:{', '.join(all_names)}" + # 准备 LLM prompt(如果配置了 LLM) + prompt = None if self.llm_model_id: - try: - entity_list_str = "\n".join(self._build_entity_lines(members)) - relationships = await self.repo.get_community_relationships(cid, end_user_id) - rel_lines = [ - f"- {r['subject']} → {r['predicate']} → {r['object']}" - for r in relationships - if r.get("subject") and r.get("predicate") and r.get("object") - ] - rel_section = ( - f"\n实体间关系:\n" + "\n".join(rel_lines) - if rel_lines else "" - ) - prompt = ( - f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" - f"请为这组实体所代表的主题:\n" - f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 写一句话摘要(不超过80个字)\n\n" - f"严格按以下格式输出,不要有其他内容:\n" - f"名称:<名称>\n摘要:<摘要>" - ) - with get_db_context() as db: - llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) - response = await llm_client.chat([{"role": "user", "content": prompt}]) - text = response.content if hasattr(response, "content") else str(response) - - for line in text.strip().splitlines(): - if line.startswith("名称:"): - name = line[3:].strip() - elif line.startswith("摘要:"): - summary = line[3:].strip() - except Exception as e: - logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}") + entity_list_str = "\n".join(self._build_entity_lines(members)) + relationships = await self.repo.get_community_relationships(cid, end_user_id) + rel_lines = [ + f"- {r['subject']} → {r['predicate']} → {r['object']}" + for r in relationships + if r.get("subject") and r.get("predicate") and r.get("object") + ] + rel_section = ( + f"\n实体间关系:\n" + "\n".join(rel_lines) + if rel_lines else "" + ) + prompt = ( + f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过80个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) return { "community_id": cid, @@ -532,14 +537,16 @@ class LabelPropagationEngine: "name": name, "summary": summary, "core_entities": core_entities, + "prompt": prompt, "summary_embedding": None, } except Exception as e: logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True) return None + # --- 阶段1:并发准备所有社区数据 --- results = await asyncio.gather( - *[_build_one(cid) for cid in community_ids], + *[_prepare_one(cid) for cid in community_ids], return_exceptions=True, ) metadata_list = [] @@ -553,19 +560,67 @@ class LabelPropagationEngine: logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}") return - # --- 阶段2:批量生成 summary_embedding --- - if self.embedding_model_id: - try: - summaries = [m["summary"] for m in metadata_list] - with get_db_context() as db: - embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) - embeddings = await embedder.response(summaries) - for i, meta in enumerate(metadata_list): - meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None - except Exception as e: - logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) + # --- 阶段2:批量调用 LLM 生成 name 和 summary --- + if self.llm_model_id: + llm_client = self._get_llm_client() + if llm_client: + prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")] + + if prompts_to_process: + logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据") + + async def _call_llm(idx: int, meta: Dict) -> tuple: + """单个 LLM 调用""" + try: + response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}]) + text = response.content if hasattr(response, "content") else str(response) + return (idx, text, None) + except Exception as e: + logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}") + return (idx, None, e) + + # 并发调用所有 LLM 请求 + llm_results = await asyncio.gather( + *[_call_llm(idx, meta) for idx, meta in prompts_to_process], + return_exceptions=True + ) + + # 解析 LLM 响应 + for result in llm_results: + if isinstance(result, Exception): + continue + idx, text, error = result + if error or not text: + continue + + meta = metadata_list[idx] + for line in text.strip().splitlines(): + if line.startswith("名称:"): + meta["name"] = line[3:].strip() + elif line.startswith("摘要:"): + meta["summary"] = line[3:].strip() + + logger.info(f"[Clustering] LLM 批量生成完成") - # --- 阶段3:写入(单个 or 批量)--- + # --- 阶段3:批量生成 summary_embedding --- + if self.embedding_model_id: + embedder = self._get_embedder_client() + if embedder: + try: + summaries = [m["summary"] for m in metadata_list] + logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding") + embeddings = await embedder.response(summaries) + for i, meta in enumerate(metadata_list): + meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + logger.info(f"[Clustering] Embedding 批量生成完成") + except Exception as e: + logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) + + # --- 阶段4:批量写入数据库 --- + # 移除 prompt 字段(不需要存储) + for m in metadata_list: + m.pop("prompt", None) + if len(metadata_list) == 1: m = metadata_list[0] result = await self.repo.update_community_metadata( @@ -582,6 +637,28 @@ class LabelPropagationEngine: ok = await self.repo.batch_update_community_metadata(metadata_list) if not ok: logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败") + else: + logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") + + def _get_llm_client(self): + """获取或创建 LLM 客户端(单例模式)""" + if self._llm_client is None and self.llm_model_id: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + with get_db_context() as db: + self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) + logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}") + return self._llm_client + + def _get_embedder_client(self): + """获取或创建 Embedder 客户端(单例模式)""" + if self._embedder_client is None and self.embedding_model_id: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + with get_db_context() as db: + self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}") + return self._embedder_client @staticmethod def _new_community_id() -> str: diff --git a/api/app/tasks.py b/api/app/tasks.py index 61736275..d5f09a29 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -2627,6 +2627,100 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ # 社区聚类补全任务(触发型) # ============================================================================= +@celery_app.task( + name="app.tasks.run_incremental_clustering", + bind=True, + ignore_result=False, + max_retries=2, + acks_late=True, + time_limit=1800, # 30分钟硬超时 + soft_time_limit=1700, +) +def run_incremental_clustering( + self, + end_user_id: str, + new_entity_ids: List[str], + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, +) -> Dict[str, Any]: + """增量聚类任务:处理新增实体的社区分配和元数据生成。 + + 此任务在后台异步执行,不阻塞 write_message 主流程。 + + Args: + end_user_id: 用户 ID + new_entity_ids: 新增实体 ID 列表 + llm_model_id: LLM 模型 ID(可选) + embedding_model_id: Embedding 模型 ID(可选) + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.logging_config import get_logger + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine + + logger = get_logger(__name__) + logger.info( + f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, " + f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}" + ) + + connector = Neo4jConnector() + try: + engine = LabelPropagationEngine( + connector=connector, + llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id, + ) + + # 执行增量聚类 + await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) + + logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}") + + return { + "status": "SUCCESS", + "end_user_id": end_user_id, + "entity_count": len(new_entity_ids), + } + except Exception as e: + logger.error(f"[IncrementalClustering] 增量聚类失败: {e}", exc_info=True) + raise + finally: + await connector.close() + + try: + loop = set_asyncio_event_loop() + result = loop.run_until_complete(_run()) + result["elapsed_time"] = time.time() - start_time + result["task_id"] = self.request.id + + logger.info( + f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, " + f"elapsed_time={result['elapsed_time']:.2f}s" + ) + + return result + except Exception as e: + elapsed_time = time.time() - start_time + logger.error( + f"[IncrementalClustering] 任务失败 - task_id={self.request.id}, " + f"elapsed_time={elapsed_time:.2f}s, error={str(e)}", + exc_info=True + ) + return { + "status": "FAILURE", + "error": str(e), + "end_user_id": end_user_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id, + } + + @celery_app.task( name="app.tasks.init_community_clustering_for_users", bind=True, From a874cc70a47dfb1bfdbfc8768df88b5e0b4aa72e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 26 Mar 2026 17:32:05 +0800 Subject: [PATCH 049/120] [changes] Add the content for client initialization failure alarm --- .../clustering_engine/label_propagation.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index d0b121d7..246453c0 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -240,8 +240,16 @@ class LabelPropagationEngine: """ 处理单个新实体的社区分配。 + 该函数会为新实体分配社区,可能的情况包括: + 1. 孤立实体(无邻居):创建新的单成员社区 + 2. 邻居都没有社区:创建新社区并将实体和邻居都加入 + 3. 邻居有社区:通过加权投票选择最合适的社区加入 + Returns: - str: 分配到的社区ID(如果有) + Optional[str]: 分配到的社区ID。当前实现总是返回一个有效的社区ID, + 但返回类型保留为Optional以支持未来可能的扩展场景 + (例如:实体无法分配到任何社区的情况)。 + 调用方应检查返回值的真假性(truthiness)。 """ neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id) @@ -563,6 +571,11 @@ class LabelPropagationEngine: # --- 阶段2:批量调用 LLM 生成 name 和 summary --- if self.llm_model_id: llm_client = self._get_llm_client() + if not llm_client: + logger.warning( + f"[Clustering] LLM 已配置(model_id={self.llm_model_id})但客户端初始化失败," + f"将跳过社区元数据的 LLM 富化。请检查 model_id 是否正确或数据库连接是否正常。" + ) if llm_client: prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")] @@ -605,6 +618,11 @@ class LabelPropagationEngine: # --- 阶段3:批量生成 summary_embedding --- if self.embedding_model_id: embedder = self._get_embedder_client() + if not embedder: + logger.warning( + f"[Clustering] Embedding 已配置(model_id={self.embedding_model_id})但客户端初始化失败," + f"将跳过社区摘要的向量化。请检查 model_id 是否正确或数据库连接是否正常。" + ) if embedder: try: summaries = [m["summary"] for m in metadata_list] From 4d39cdf464c214c87967858cbd915377e65d5e82 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Thu, 26 Mar 2026 18:28:19 +0800 Subject: [PATCH 050/120] fix(app): The opening remarks and the referenced documents have been saved in the history. --- api/app/services/app_chat_service.py | 69 +++++++++++++++++------ api/app/services/draft_run_service.py | 81 ++++++++++++++++++--------- 2 files changed, 107 insertions(+), 43 deletions(-) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 3dda6fc0..90474428 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -82,12 +82,6 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt - # opening_statement:首轮对话注入开场白 - is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) - system_prompt = self.agent_service._inject_opening_statement( - features_config, system_prompt, is_new_conversation - ) - # 准备工具列表 tools = [] @@ -135,7 +129,7 @@ class AppChatService: model_type=ModelType.LLM ) - # 加载历史消息 + # 加载历史消息(包含开场白) history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, max_history=10, @@ -143,6 +137,25 @@ class AppChatService: current_is_omni=api_key_obj.is_omni ) + # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 + is_new_conversation = len(history) == 0 + if is_new_conversation: + opening = self.agent_service._get_opening_statement(features_config, True, variables) + if opening: + self.conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=opening, + meta_data={} + ) + # 重新加载历史(包含刚写入的开场白) + history = await self.conversation_service.get_conversation_history( + conversation_id=conversation_id, + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni + ) + # 处理多模态文件 processed_files = None if files: @@ -184,6 +197,9 @@ class AppChatService: tenant_id=tenant_id, workspace_id=workspace_id ) + # 过滤 citations(只调用一次) + filtered_citations = self.agent_service._filter_citations(features_config, citations_collector) + # 构建用户消息内容(含多模态文件) human_meta = { "files": [], @@ -192,7 +208,8 @@ class AppChatService: assistant_meta = { "model": api_key_obj.model_name, "usage": result.get("usage", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}), - "audio_url": None + "audio_url": None, + "citations": filtered_citations } if files: for f in files: @@ -237,7 +254,7 @@ class AppChatService: }), "elapsed_time": elapsed_time, "suggested_questions": suggested_questions, - "citations": self.agent_service._filter_citations(features_config, citations_collector), + "citations": filtered_citations, "audio_url": audio_url, "audio_status": "pending" } @@ -290,12 +307,6 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt - # opening_statement:首轮对话注入开场白 - is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) - system_prompt = self.agent_service._inject_opening_statement( - features_config, system_prompt, is_new_conversation - ) - # 准备工具列表 tools = [] @@ -345,7 +356,7 @@ class AppChatService: model_type=ModelType.LLM ) - # 加载历史消息 + # 加载历史消息(包含开场白) history = await self.conversation_service.get_conversation_history( conversation_id=conversation_id, max_history=10, @@ -353,6 +364,25 @@ class AppChatService: current_is_omni=api_key_obj.is_omni ) + # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 + is_new_conversation = len(history) == 0 + if is_new_conversation: + opening = self.agent_service._get_opening_statement(features_config, True, variables) + if opening: + self.conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=opening, + meta_data={} + ) + # 重新加载历史(包含刚写入的开场白) + history = await self.conversation_service.get_conversation_history( + conversation_id=conversation_id, + max_history=10, + current_provider=api_key_obj.provider, + current_is_omni=api_key_obj.is_omni + ) + # 处理多模态文件 processed_files = None if files: @@ -423,7 +453,9 @@ class AppChatService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector) + # 过滤 citations(只调用一次) + filtered_citations = self.agent_service._filter_citations(features_config, citations_collector) + end_data["citations"] = filtered_citations # 保存消息 human_meta = { @@ -433,7 +465,8 @@ class AppChatService: assistant_meta = { "model": api_key_obj.model_name, "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}, - "audio_url": None + "audio_url": None, + "citations": filtered_citations } if files: diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ac34b4de..e188872f 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -445,19 +445,27 @@ class AgentRunService: ) @staticmethod - def _inject_opening_statement( + def _get_opening_statement( features_config: Dict[str, Any], - system_prompt: str, - is_new_conversation: bool - ) -> str: - """首轮对话时将开场白注入 system_prompt""" + is_new_conversation: bool, + variables: Optional[Dict[str, Any]] = None + ) -> Optional[str]: + """首轮对话时返回开场白文本(支持变量替换),否则返回 None""" if not is_new_conversation: - return system_prompt + return None opening = features_config.get("opening_statement", {}) if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")): - return system_prompt + return None + statement = opening["statement"] - return f"{system_prompt}\n\n[对话开场白]\n{statement}" + + # 如果有变量,进行替换(仅支持 {{var_name}} 格式) + if variables: + for var_name, var_value in variables.items(): + placeholder = f"{{{{{var_name}}}}}" + statement = statement.replace(placeholder, str(var_value)) + + return statement @staticmethod def _filter_citations( @@ -555,10 +563,6 @@ class AgentRunService: # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - # opening_statement:首轮对话注入开场白 - is_new_conversation = not conversation_id - system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation) - # 4. 准备工具列表 tools = [] @@ -593,12 +597,15 @@ class AgentRunService: tools=tools, ) - # 5. 处理会话ID(创建或验证) + # 5. 处理会话ID(创建或验证),新会话时写入开场白 + is_new_conversation = not conversation_id + opening = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + opening_statement=opening ) model_info = ModelInfo( @@ -611,7 +618,7 @@ class AgentRunService: model_type=model_config.type ) - # 6. 加载历史消息 + # 6. 加载历史消息(包含开场白) history = await self._load_conversation_history( conversation_id=conversation_id, max_history=10, @@ -668,6 +675,9 @@ class AgentRunService: tenant_id=tenant_id, workspace_id=workspace_id ) if not sub_agent else None + # 过滤 citations(只调用一次) + filtered_citations = self._filter_citations(features_config, citations_collector) + # 10. 保存会话消息 if not sub_agent: await self._save_conversation_message( @@ -686,6 +696,7 @@ class AgentRunService: files=files, processed_files=processed_files, audio_url=audio_url, + citations=filtered_citations, provider=api_key_config.get("provider"), is_omni=api_key_config.get("is_omni", False) ) @@ -702,7 +713,7 @@ class AgentRunService: "suggested_questions": await self._generate_suggested_questions( features_config, result["content"], api_key_config, effective_params ) if not sub_agent else [], - "citations": self._filter_citations(features_config, citations_collector), + "citations": filtered_citations, "audio_url": audio_url, "audio_status": "pending" } @@ -797,10 +808,6 @@ class AgentRunService: # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - # opening_statement:首轮对话注入开场白 - is_new_conversation = not conversation_id - system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation) - # 4. 准备工具列表 tools = [] @@ -836,13 +843,16 @@ class AgentRunService: streaming=True ) - # 5. 处理会话ID(创建或验证) + # 5. 处理会话ID(创建或验证),新会话时写入开场白 + is_new_conversation = not conversation_id + opening = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, - sub_agent=sub_agent + sub_agent=sub_agent, + opening_statement=opening ) model_info = ModelInfo( @@ -926,6 +936,9 @@ class AgentRunService: if sub_agent: yield self._format_sse_event("sub_usage", {"total_tokens": total_tokens}) + # 过滤 citations(只调用一次) + filtered_citations = self._filter_citations(features_config, citations_collector) + # 11. 保存会话消息 if not sub_agent: await self._save_conversation_message( @@ -940,6 +953,7 @@ class AgentRunService: files=files, processed_files=processed_files, audio_url=stream_audio_url, + citations=filtered_citations, provider=api_key_config.get("provider"), is_omni=api_key_config.get("is_omni", False) ) @@ -966,7 +980,7 @@ class AgentRunService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self._filter_citations(features_config, citations_collector) + end_data["citations"] = filtered_citations yield self._format_sse_event("end", end_data) logger.info( @@ -1046,7 +1060,8 @@ class AgentRunService: app_id: uuid.UUID, workspace_id: uuid.UUID, user_id: Optional[str], - sub_agent: bool = False + sub_agent: bool = False, + opening_statement: Optional[str] = None ) -> str: """确保会话存在(创建或验证) @@ -1055,6 +1070,8 @@ class AgentRunService: app_id: 应用ID workspace_id: 工作空间ID(必须) user_id: 用户ID + sub_agent: 是否为子代理 + opening_statement: 开场白(新会话时作为第一条消息写入) Returns: str: 会话ID @@ -1092,6 +1109,16 @@ class AgentRunService: self.db.commit() self.db.refresh(new_conversation) + # 如果有开场白,作为第一条 assistant 消息写入数据库 + if opening_statement: + conversation_service.add_message( + conversation_id=uuid.UUID(new_conv_id), + role="assistant", + content=opening_statement, + meta_data={} + ) + logger.debug(f"已保存开场白到会话 {new_conv_id}") + logger.info( "创建草稿会话成功", extra={ @@ -1215,6 +1242,7 @@ class AgentRunService: files: Optional[List[FileInput]] = None, processed_files: Optional[List[Dict[str, Any]]] = None, audio_url: Optional[str] = None, + citations: Optional[List[Any]] = None, provider: Optional[str] = None, is_omni: Optional[bool] = None ) -> None: @@ -1230,6 +1258,7 @@ class AgentRunService: files: 原始文件输入 processed_files: 处理后的文件 audio_url: 音频URL + citations: 引用来源列表 provider: 模型供应商 is_omni: 是否为全模态模型 """ @@ -1266,9 +1295,11 @@ class AgentRunService: content=user_message, meta_data=human_meta ) - # 保存助手消息(含 audio_url) + # 保存助手消息(含 audio_url 和 citations) if audio_url: meta_data["audio_url"] = audio_url + if citations: + meta_data["citations"] = citations conversation_service.add_message( conversation_id=conv_uuid, role="assistant", From b35bedc7305e0701c5d508702ae4b393d1c23f48 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 26 Mar 2026 18:30:59 +0800 Subject: [PATCH 051/120] [changes] New field added --- api/app/models/end_user_model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index a821680f..ff46786a 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -22,6 +22,14 @@ class EndUser(Base): created_at = Column(DateTime, default=datetime.datetime.now) updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) + # 用户档案字段 - User Profile Fields + position = Column(String, nullable=True, comment="职位") + department = Column(String, nullable=True, comment="部门") + contact = Column(String, nullable=True, comment="联系方式") + phone = Column(String, nullable=True, comment="电话") + hire_date = Column(DateTime, nullable=True, comment="入职日期") + updatetime_profile = Column(DateTime, nullable=True, comment="核心档案信息最后更新时间") + memory_config_id = Column( UUID(as_uuid=True), ForeignKey("memory_config.config_id"), From 3d291e3c23d65f285e7962891eb8e2c15f230d3e Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 26 Mar 2026 18:34:19 +0800 Subject: [PATCH 052/120] [add] migration script --- .../versions/1480a7d680fb_202603261815.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 api/migrations/versions/1480a7d680fb_202603261815.py diff --git a/api/migrations/versions/1480a7d680fb_202603261815.py b/api/migrations/versions/1480a7d680fb_202603261815.py new file mode 100644 index 00000000..4c6f8c9c --- /dev/null +++ b/api/migrations/versions/1480a7d680fb_202603261815.py @@ -0,0 +1,59 @@ +"""202603261815 + +Revision ID: 1480a7d680fb +Revises: adaefcbe2aa1 +Create Date: 2026-03-26 18:16:07.886033 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '1480a7d680fb' +down_revision: Union[str, None] = 'adaefcbe2aa1' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('end_user_info', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('end_user_id', sa.UUID(), nullable=False, comment='关联的终端用户ID'), + sa.Column('other_name', sa.String(), nullable=False, comment='关联的用户名称'), + sa.Column('aliases', sa.ARRAY(sa.String()), nullable=True, comment='用户别名列表(字符串数组)'), + sa.Column('meta_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='用户相关的扩展信息(JSON格式)'), + sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'), + sa.ForeignKeyConstraint(['end_user_id'], ['end_users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_end_user_info_end_user_id'), 'end_user_info', ['end_user_id'], unique=False) + op.create_index(op.f('ix_end_user_info_id'), 'end_user_info', ['id'], unique=False) + + connection = op.get_bind() + connection.execute(sa.text(""" + INSERT INTO end_user_info (id, end_user_id, other_name, aliases, meta_data, created_at, updated_at) + SELECT + gen_random_uuid() as id, + id as end_user_id, + other_name, + '{}'::TEXT[] as aliases, + NULL as meta_data, + NOW() as created_at, + NOW() as updated_at + FROM end_users + """)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_end_user_info_id'), table_name='end_user_info') + op.drop_index(op.f('ix_end_user_info_end_user_id'), table_name='end_user_info') + op.drop_table('end_user_info') + # ### end Alembic commands ### From 06b823ff96e4505afa3dcb84c95b654138b87e32 Mon Sep 17 00:00:00 2001 From: wxy Date: Thu, 26 Mar 2026 18:48:20 +0800 Subject: [PATCH 053/120] fix: prevent token refresh when tenant is disabled --- api/app/repositories/user_repository.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/api/app/repositories/user_repository.py b/api/app/repositories/user_repository.py index b4c11aa4..3f8919aa 100644 --- a/api/app/repositories/user_repository.py +++ b/api/app/repositories/user_repository.py @@ -19,18 +19,22 @@ class UserRepository: self.db = db def get_user_by_id(self, user_id: uuid.UUID) -> Optional[User]: - """根据ID获取用户""" - db_logger.debug(f"根据ID查询用户: user_id={user_id}") + """根据 ID 获取用户(租户禁用时返回 None)""" + db_logger.debug(f"根据 ID 查询用户:user_id={user_id}") try: user = self.db.query(User).options(joinedload(User.tenant)).filter(User.id == user_id).first() if user: - db_logger.debug(f"用户查询成功: {user.username} (ID: {user_id})") + # 检查租户状态,租户禁用时返回 None + if user.tenant and not user.tenant.is_active: + db_logger.warning(f"用户 {user.username} (ID: {user_id}) 所属租户 {user.tenant_id} 已被禁用") + return None + db_logger.debug(f"用户查询成功:{user.username} (ID: {user_id})") else: - db_logger.debug(f"用户不存在: user_id={user_id}") + db_logger.debug(f"用户不存在:user_id={user_id}") return user except Exception as e: - db_logger.error(f"根据ID查询用户失败: user_id={user_id} - {str(e)}") + db_logger.error(f"根据 ID 查询用户失败:user_id={user_id} - {str(e)}") raise def get_user_by_email(self, email: str) -> Optional[User]: From 35be03803f8ba8be991638c815cb0c572765b81f Mon Sep 17 00:00:00 2001 From: wxy Date: Thu, 26 Mar 2026 18:56:43 +0800 Subject: [PATCH 054/120] feat: add tenant relationship and status fields to User model --- api/app/models/tenant_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index 8f101eb5..a92b5629 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -35,10 +35,6 @@ class Tenants(Base): api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制 status = Column(String(50), nullable=True, default='active') # 租户状态 - # 租户功能开关字段 - feature_billing = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用收费管理菜单") - feature_user_management = Column(Boolean, default=False, nullable=False, server_default='false', comment="是否启用用户管理菜单") - # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") From 3ed6f49bb02ac142d23e269e6e0c112612b3455a Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 26 Mar 2026 19:56:31 +0800 Subject: [PATCH 055/120] [add] migration script --- .../versions/6b8a461148ff_202603261955.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 api/migrations/versions/6b8a461148ff_202603261955.py diff --git a/api/migrations/versions/6b8a461148ff_202603261955.py b/api/migrations/versions/6b8a461148ff_202603261955.py new file mode 100644 index 00000000..a0bdac87 --- /dev/null +++ b/api/migrations/versions/6b8a461148ff_202603261955.py @@ -0,0 +1,32 @@ +"""202603261955 + +Revision ID: 6b8a461148ff +Revises: 1480a7d680fb +Create Date: 2026-03-26 19:55:24.041039 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '6b8a461148ff' +down_revision: Union[str, None] = '1480a7d680fb' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('tenants', 'feature_user_management') + op.drop_column('tenants', 'feature_billing') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tenants', sa.Column('feature_billing', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False, comment='是否启用收费管理菜单')) + op.add_column('tenants', sa.Column('feature_user_management', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False, comment='是否启用用户管理菜单')) + # ### end Alembic commands ### From a5bce221bd0ea74aa6c1200f9280f319c7d2c2c7 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Thu, 26 Mar 2026 20:12:11 +0800 Subject: [PATCH 056/120] refactor(memory-api): migrate end user creation to authenticated API endpoint - Remove unauthenticated end_user_controller and its router registration - Move end user creation logic to authenticated memory_api_controller endpoint - Add create_end_user method to MemoryAPIService with workspace authorization - Fix retrieve_nodes import in read_graph to use correct function reference - Consolidate end user management under authenticated memory API with API key scoping --- api/app/controllers/__init__.py | 2 - api/app/controllers/end_user_controller.py | 48 ------------------- .../service/memory_api_controller.py | 30 ++++++++++++ .../agent/langgraph_graph/read_graph.py | 6 +-- api/app/services/memory_api_service.py | 47 ++++++++++++++++++ 5 files changed, 80 insertions(+), 53 deletions(-) delete mode 100644 api/app/controllers/end_user_controller.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 869eb039..50e9e0b0 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -14,7 +14,6 @@ from . import ( document_controller, emotion_config_controller, emotion_controller, - end_user_controller, file_controller, file_storage_controller, home_page_controller, @@ -99,6 +98,5 @@ manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) manager_router.include_router(skill_controller.router) manager_router.include_router(i18n_controller.router) -manager_router.include_router(end_user_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/end_user_controller.py b/api/app/controllers/end_user_controller.py deleted file mode 100644 index b9d54fea..00000000 --- a/api/app/controllers/end_user_controller.py +++ /dev/null @@ -1,48 +0,0 @@ -"""End User 管理接口 - 无需认证""" - -from app.core.logging_config import get_business_logger -from app.core.response_utils import success -from app.db import get_db -from app.repositories.end_user_repository import EndUserRepository -from app.schemas.memory_api_schema import ( - CreateEndUserRequest, - CreateEndUserResponse, -) -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -router = APIRouter(prefix="/end_users", tags=["End Users"]) -logger = get_business_logger() - - -@router.post("") -async def create_end_user( - data: CreateEndUserRequest, - db: Session = Depends(get_db), -): - """ - Create an end user. - - Creates a new end user for the given workspace. - If an end user with the same other_id already exists in the workspace, - returns the existing one. - """ - logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}") - - end_user_repo = EndUserRepository(db) - end_user = end_user_repo.get_or_create_end_user( - app_id=None, - workspace_id=data.workspace_id, - other_id=data.other_id, - ) - - logger.info(f"End user ready: {end_user.id}") - - result = { - "id": str(end_user.id), - "other_id": end_user.other_id or "", - "other_name": end_user.other_name or "", - "workspace_id": str(end_user.workspace_id), - } - - return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 08a94a89..dc5e0408 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -6,6 +6,8 @@ from app.core.response_utils import success from app.db import get_db from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.memory_api_schema import ( + CreateEndUserRequest, + CreateEndUserResponse, ListConfigsResponse, MemoryReadRequest, MemoryReadResponse, @@ -113,3 +115,31 @@ async def list_memory_configs( logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}") return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully") + + +@router.post("/end_users") +@require_api_key(scopes=["memory"]) +async def create_end_user( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Create an end user. + + Creates a new end user for the authorized workspace. + If an end user with the same other_id already exists, returns the existing one. + """ + body = await request.json() + payload = CreateEndUserRequest(**body) + logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}") + + memory_api_service = MemoryAPIService(db) + + result = memory_api_service.create_end_user( + workspace_id=api_key_auth.workspace_id, + other_id=payload.other_id, + ) + + logger.info(f"End user ready: {result['id']}") + return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index bddae618..e698e6ad 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -15,7 +15,7 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( Problem_Extension, ) from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( - retrieve, + retrieve_nodes, ) from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( Input_Summary, @@ -53,8 +53,8 @@ async def make_read_graph(): workflow.add_node("Split_The_Problem", Split_The_Problem) workflow.add_node("Problem_Extension", Problem_Extension) workflow.add_node("Input_Summary", Input_Summary) - # workflow.add_node("Retrieve", retrieve_nodes) - workflow.add_node("Retrieve", retrieve) + workflow.add_node("Retrieve", retrieve_nodes) + # workflow.add_node("Retrieve", retrieve) workflow.add_node("Verify", Verify) workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Summary", Summary) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 9282fc28..f62f526c 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -280,6 +280,53 @@ class MemoryAPIService: code=BizCode.MEMORY_READ_FAILED ) + def create_end_user( + self, + workspace_id: uuid.UUID, + other_id: str, + ) -> Dict[str, Any]: + """Create or retrieve an end user for the workspace. + + Uses get_or_create semantics: if an end user with the same other_id + already exists in the workspace, returns the existing one. + + Args: + workspace_id: Workspace ID from API key authorization + other_id: External user identifier + + Returns: + Dict with id, other_id, other_name, and workspace_id + + Raises: + BusinessException: If creation fails + """ + logger.info(f"Creating end user - other_id: {other_id}, workspace_id: {workspace_id}") + + try: + from app.repositories.end_user_repository import EndUserRepository + + end_user_repo = EndUserRepository(self.db) + end_user = end_user_repo.get_or_create_end_user( + app_id=None, + workspace_id=workspace_id, + other_id=other_id, + ) + + logger.info(f"End user ready: {end_user.id}") + return { + "id": str(end_user.id), + "other_id": end_user.other_id or "", + "other_name": end_user.other_name or "", + "workspace_id": str(end_user.workspace_id), + } + + except Exception as e: + logger.error(f"Failed to create end user for workspace {workspace_id}: {e}") + raise BusinessException( + message=f"Failed to create end user: {str(e)}", + code=BizCode.INTERNAL_ERROR + ) + def list_memory_configs( self, workspace_id: uuid.UUID, From ac7c891ded676ebd90f2f49d84f2ac8989d21d05 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 26 Mar 2026 20:44:55 +0800 Subject: [PATCH 057/120] =?UTF-8?q?=E6=B4=BB=E5=8A=A8=E7=BB=9F=E8=AE=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/cache/memory/activity_stats_cache.py | 210 +++++++++--------- .../core/memory/agent/utils/write_tools.py | 38 ++-- 2 files changed, 124 insertions(+), 124 deletions(-) diff --git a/api/app/cache/memory/activity_stats_cache.py b/api/app/cache/memory/activity_stats_cache.py index 6b162cdd..35c702b1 100644 --- a/api/app/cache/memory/activity_stats_cache.py +++ b/api/app/cache/memory/activity_stats_cache.py @@ -1,124 +1,124 @@ -""" -Recent Activity Stats Cache +# """ +# Recent Activity Stats Cache -记忆提取活动统计缓存模块 -用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 -查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 -""" -import json -import logging -from typing import Optional, Dict, Any -from datetime import datetime +# 记忆提取活动统计缓存模块 +# 用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 +# 查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 +# """ +# import json +# import logging +# from typing import Optional, Dict, Any +# from datetime import datetime -from app.aioRedis import aio_redis +# from app.aioRedis import aio_redis -logger = logging.getLogger(__name__) +# logger = logging.getLogger(__name__) -# 缓存过期时间:24小时 -ACTIVITY_STATS_CACHE_EXPIRE = 86400 +# # 缓存过期时间:24小时 +# ACTIVITY_STATS_CACHE_EXPIRE = 86400 -class ActivityStatsCache: - """记忆提取活动统计缓存类""" +# class ActivityStatsCache: +# """记忆提取活动统计缓存类""" - PREFIX = "cache:memory:activity_stats" +# PREFIX = "cache:memory:activity_stats" - @classmethod - def _get_key(cls, workspace_id: str) -> str: - """生成 Redis key +# @classmethod +# def _get_key(cls, workspace_id: str) -> str: +# """生成 Redis key - Args: - workspace_id: 工作空间ID +# Args: +# workspace_id: 工作空间ID - Returns: - 完整的 Redis key - """ - return f"{cls.PREFIX}:by_workspace:{workspace_id}" +# Returns: +# 完整的 Redis key +# """ +# return f"{cls.PREFIX}:by_workspace:{workspace_id}" - @classmethod - async def set_activity_stats( - cls, - workspace_id: str, - stats: Dict[str, Any], - expire: int = ACTIVITY_STATS_CACHE_EXPIRE, - ) -> bool: - """设置记忆提取活动统计缓存 +# @classmethod +# async def set_activity_stats( +# cls, +# workspace_id: str, +# stats: Dict[str, Any], +# expire: int = ACTIVITY_STATS_CACHE_EXPIRE, +# ) -> bool: +# """设置记忆提取活动统计缓存 - Args: - workspace_id: 工作空间ID - stats: 统计数据,格式: - { - "chunk_count": int, - "statements_count": int, - "triplet_entities_count": int, - "triplet_relations_count": int, - "temporal_count": int, - } - expire: 过期时间(秒),默认24小时 +# Args: +# workspace_id: 工作空间ID +# stats: 统计数据,格式: +# { +# "chunk_count": int, +# "statements_count": int, +# "triplet_entities_count": int, +# "triplet_relations_count": int, +# "temporal_count": int, +# } +# expire: 过期时间(秒),默认24小时 - Returns: - 是否设置成功 - """ - try: - key = cls._get_key(workspace_id) - payload = { - "stats": stats, - "generated_at": datetime.now().isoformat(), - "workspace_id": workspace_id, - "cached": True, - } - value = json.dumps(payload, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) - logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") - return True - except Exception as e: - logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) - return False +# Returns: +# 是否设置成功 +# """ +# try: +# key = cls._get_key(workspace_id) +# payload = { +# "stats": stats, +# "generated_at": datetime.now().isoformat(), +# "workspace_id": workspace_id, +# "cached": True, +# } +# value = json.dumps(payload, ensure_ascii=False) +# await aio_redis.set(key, value, ex=expire) +# logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") +# return True +# except Exception as e: +# logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) +# return False - @classmethod - async def get_activity_stats( - cls, - workspace_id: str, - ) -> Optional[Dict[str, Any]]: - """获取记忆提取活动统计缓存 +# @classmethod +# async def get_activity_stats( +# cls, +# workspace_id: str, +# ) -> Optional[Dict[str, Any]]: +# """获取记忆提取活动统计缓存 - Args: - workspace_id: 工作空间ID +# Args: +# workspace_id: 工作空间ID - Returns: - 统计数据字典,缓存不存在或已过期返回 None - """ - try: - key = cls._get_key(workspace_id) - value = await aio_redis.get(key) - if value: - payload = json.loads(value) - logger.info(f"命中活动统计缓存: {key}") - return payload - logger.info(f"活动统计缓存不存在或已过期: {key}") - return None - except Exception as e: - logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) - return None +# Returns: +# 统计数据字典,缓存不存在或已过期返回 None +# """ +# try: +# key = cls._get_key(workspace_id) +# value = await aio_redis.get(key) +# if value: +# payload = json.loads(value) +# logger.info(f"命中活动统计缓存: {key}") +# return payload +# logger.info(f"活动统计缓存不存在或已过期: {key}") +# return None +# except Exception as e: +# logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) +# return None - @classmethod - async def delete_activity_stats( - cls, - workspace_id: str, - ) -> bool: - """删除记忆提取活动统计缓存 +# @classmethod +# async def delete_activity_stats( +# cls, +# workspace_id: str, +# ) -> bool: +# """删除记忆提取活动统计缓存 - Args: - workspace_id: 工作空间ID +# Args: +# workspace_id: 工作空间ID - Returns: - 是否删除成功 - """ - try: - key = cls._get_key(workspace_id) - result = await aio_redis.delete(key) - logger.info(f"删除活动统计缓存: {key}, 结果: {result}") - return result > 0 - except Exception as e: - logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) - return False +# Returns: +# 是否删除成功 +# """ +# try: +# key = cls._get_key(workspace_id) +# result = await aio_redis.delete(key) +# logger.info(f"删除活动统计缓存: {key}, 结果: {result}") +# return result > 0 +# except Exception as e: +# logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) +# return False diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 55bcb8ba..c01a36d1 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -260,24 +260,24 @@ async def write( with open(log_file, "a", encoding="utf-8") as f: f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") - # 将提取统计写入 Redis,按 workspace_id 存储 - try: - from app.cache.memory.activity_stats_cache import ActivityStatsCache + # # 将提取统计写入 Redis,按 workspace_id 存储 + # try: + # from app.cache.memory.activity_stats_cache import ActivityStatsCache - stats_to_cache = { - "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, - "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, - "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, - "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, - "temporal_count": 0, - } - await ActivityStatsCache.set_activity_stats( - workspace_id=str(memory_config.workspace_id), - stats=stats_to_cache, - ) - logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") - except Exception as cache_err: - logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + # stats_to_cache = { + # "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, + # "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, + # "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, + # "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, + # "temporal_count": 0, + # } + # await ActivityStatsCache.set_activity_stats( + # workspace_id=str(memory_config.workspace_id), + # stats=stats_to_cache, + # ) + # logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") + # except Exception as cache_err: + # logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) - logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") + # logger.info("=== Pipeline Complete ===") + # logger.info(f"Total execution time: {total_time:.2f} seconds") From c93627750758b51e73b58e431c52d55771b35bce Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Mar 2026 10:15:29 +0800 Subject: [PATCH 058/120] [changes] Annotation Memory --- api/app/cache/memory/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 551062ac..fa9ad1b1 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -4,9 +4,9 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 """ from .interest_memory import InterestMemoryCache -from .activity_stats_cache import ActivityStatsCache +# from .activity_stats_cache import ActivityStatsCache __all__ = [ "InterestMemoryCache", - "ActivityStatsCache", + # "ActivityStatsCache", ] From 4534b65d6a67392c1a29ce0aac0452ec89e9fd8d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 27 Mar 2026 11:56:22 +0800 Subject: [PATCH 059/120] refactor(workflow): optimize workflow history queries and migrate ORM to SQLAlchemy 2.0 - Migrate historical workflow queries from legacy ORM Query API to SQLAlchemy 2.0 select() + execute() - Limit query fields and use pagination to reduce returned data, improving performance - Preserve original ordering and filtering logic --- api/app/repositories/workflow_repository.py | 35 ++++++++------ api/app/services/workflow_service.py | 51 +++++++++++---------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py index 4e24faa0..a783fe3f 100644 --- a/api/app/repositories/workflow_repository.py +++ b/api/app/repositories/workflow_repository.py @@ -3,9 +3,9 @@ """ import uuid -from typing import Any, Annotated +from typing import Any, Annotated, Literal from sqlalchemy.orm import Session -from sqlalchemy import desc +from sqlalchemy import desc, select from fastapi import Depends from app.models.workflow_model import ( @@ -128,29 +128,36 @@ class WorkflowExecutionRepository: Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.app_id == app_id ).order_by( desc(WorkflowExecution.started_at) - ).limit(limit).offset(offset).all() + ).limit(limit).offset(offset) + return list(self.db.execute(stmt).scalars()) def get_by_conversation_id( self, - conversation_id: uuid.UUID + conversation_id: uuid.UUID, + status: Literal["running", "completed", "failed"] = None, + limit_count: int = 50 ) -> list[WorkflowExecution]: """根据会话 ID 获取执行记录列表 Args: + limit_count: conversation_id: 会话 ID + status: 状态(可选) Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.conversation_id == conversation_id - ).order_by( - desc(WorkflowExecution.started_at) - ).all() + ) + if status: + stmt = stmt.filter(WorkflowExecution.status == status) + stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count) + return list(self.db.execute(stmt).scalars()) def count_by_app_id(self, app_id: uuid.UUID) -> int: """统计应用的执行次数 @@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表(按执行顺序排序) """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id ).order_by( WorkflowNodeExecution.execution_order - ).all() + ) + return list(self.db.execute(stmt).scalars()) def get_by_node_id( self, @@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表 """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id, WorkflowNodeExecution.node_id == node_id ).order_by( WorkflowNodeExecution.retry_count - ).all() + ) + return list(self.db.execute(stmt).scalars()) # ==================== 依赖注入函数 ==================== diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index c7d7f2b1..13267078 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -561,6 +561,24 @@ class WorkflowService: storage_type = 'neo4j' return storage_type, user_rag_memory_id + def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None: + executions = self.execution_repo.get_by_conversation_id( + conversation_id=conversation_id, + status="completed", + limit_count=1 + ) + + if executions: + last_state = executions[0].output_data + if isinstance(last_state, dict): + variables = last_state.get("variables", {}) + conv_vars = variables.get("conv", {}) + # input_data["conv"] = conv_vars + # input_data["conv_messages"] = last_state.get("messages") or [] + conv_messages = last_state.get("messages") or [] + return conv_vars, conv_messages + return None + # ==================== 工作流执行 ==================== async def run( @@ -634,18 +652,11 @@ class WorkflowService: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break - + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) result = await execute_workflow( @@ -807,17 +818,11 @@ class WorkflowService: storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) message_id = uuid.uuid4() async for event in execute_workflow_stream( From 7fd00009a21ed3334f6c72b9068fd3a018ffd637 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 27 Mar 2026 12:00:30 +0800 Subject: [PATCH 060/120] perf(workflow): introduce LazyDict to reduce variable serialization, optimize regex to reduce compilation - Use LazyDict for deferred serialization, improving performance - Reuse regex patterns to avoid repeated compilation --- api/app/core/workflow/engine/state_manager.py | 8 +- api/app/core/workflow/engine/variable_pool.py | 60 ++++++- api/app/core/workflow/nodes/base_node.py | 12 +- .../core/workflow/nodes/cycle_graph/loop.py | 15 +- .../workflow/utils/expression_evaluator.py | 78 ++++----- .../core/workflow/utils/template_renderer.py | 164 +++++++++--------- 6 files changed, 188 insertions(+), 149 deletions(-) diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 2da0d3a8..eed44278 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType def merge_activate_state(x, y): - return { - k: x.get(k, False) or y.get(k, False) - for k in set(x) | set(y) - } + merged = dict(x) + for k, v in y.items(): + merged[k] = merged.get(k, False) or v + return merged def merge_looping_state(x, y): diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 60f1257e..7faca82d 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta logger = logging.getLogger(__name__) +VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}") + + +class LazyVariableDict: + def __init__(self, source, literal): + self._source: dict[str, VariableStruct[Any]] = source + self._literal: bool = literal + self._cache = {} + + def keys(self): + return self._source.keys() + + def _resolve(self, key): + if key in self._cache: + return self._cache[key] + var_struct = self._source.get(key) + if var_struct is None: + raise KeyError(key) + value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value() + self._cache[key] = value + return value + + def get(self, key, default=None): + try: + return self._resolve(key) + except KeyError: + return default + + def __getitem__(self, key): + return self._resolve(key) + + def __getattr__(self, key): + if key.startswith('_'): + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + return self._resolve(key) + + def __contains__(self, key): + return key in self._source + + def __iter__(self): + return iter(self._source) + + def __len__(self): + return len(self._source) + class VariableSelector: """变量选择器 @@ -117,8 +162,7 @@ class VariablePool: @staticmethod def transform_selector(selector): - pattern = r"\{\{\s*(.*?)\s*\}\}" - variable_literal = re.sub(pattern, r"\1", selector).strip() + variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip() selector = VariableSelector.from_string(variable_literal).path if len(selector) != 2: raise ValueError(f"Selector not valid - {selector}") @@ -303,6 +347,16 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None + def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict: + return LazyVariableDict(self.variables.get(namespace, {}), literal) + + def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]: + return { + ns: LazyVariableDict(vars_dict, literal) + for ns, vars_dict in self.variables.items() + if ns not in ("sys", "conv") + } + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 @@ -479,5 +533,3 @@ class VariablePoolInitializer: var_type=var_type, mut=False ) - - diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 8567ebbe..bedf6165 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -552,9 +552,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(literal=True), - node_outputs=variable_pool.get_all_node_outputs(literal=True), - system_vars=variable_pool.get_all_system_vars(literal=True), + conv_vars=variable_pool.lazy_namespace("conv", literal=True), + node_outputs=variable_pool.lazy_all_node_outputs(literal=True), + system_vars=variable_pool.lazy_namespace("sys", literal=True), strict=strict ) @@ -579,9 +579,9 @@ class BaseNode(ABC): return evaluate_condition( expression=expression, - conv_var=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars() + conv_var=variable_pool.lazy_namespace("conv"), + node_outputs=variable_pool.lazy_all_node_outputs(), + system_vars=variable_pool.lazy_namespace("sys") ) @staticmethod diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 84901bad..e555a228 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance -from app.core.workflow.utils.expression_evaluator import evaluate_expression logger = logging.getLogger(__name__) @@ -85,12 +84,7 @@ class LoopRuntime: for variable in self.typed_config.cycle_vars: if variable.input_type == ValueInputType.VARIABLE: - value = evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + value = self.variable_pool.get_value(variable.value) else: value = TypeTransformer.transform(variable.value, variable.type) await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) @@ -98,12 +92,7 @@ class LoopRuntime: **self.state ) loopstate["node_outputs"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + variable.name: self.variable_pool.get_value(variable.value) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars diff --git a/api/app/core/workflow/utils/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py index 4bc5fc4c..05a3294b 100644 --- a/api/app/core/workflow/utils/expression_evaluator.py +++ b/api/app/core/workflow/utils/expression_evaluator.py @@ -4,32 +4,33 @@ from typing import Any from simpleeval import simple_eval, NameNotDefined, InvalidExpression +from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class ExpressionEvaluator: """Safe expression evaluator for workflow variables and node outputs.""" - + # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} @classmethod def normalize_template(cls, template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @classmethod def evaluate( - cls, - expression: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + cls, + expression: str, + conv_vars: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> Any: """ Safely evaluate an expression using workflow variables. @@ -49,48 +50,47 @@ class ExpressionEvaluator: # Remove Jinja2-style brackets if present expression = expression.strip() expression = cls.normalize_template(expression) - pattern = r"\{\{\s*(.*?)\s*\}\}" - expression = re.sub(pattern, r"\1", expression).strip() + expression = VARIABLE_PATTERN.sub(r"\1", expression).strip() # Build context for evaluation context = { - "conv": conv_vars, # conversation variables - "node": node_outputs, # node outputs - "sys": system_vars or {}, # system variables + "conv": conv_vars, # conversation variables + "node": node_outputs, # node outputs + "sys": system_vars or {}, # system variables } - context.update(conv_vars) - context["nodes"] = node_outputs + # context.update(conv_vars) + # context["nodes"] = node_outputs context.update(node_outputs) - + try: # simpleeval supports safe operations: # arithmetic, comparisons, logical ops, attribute/dict/list access result = simple_eval(expression, names=context) return result - + except NameNotDefined as e: logger.error(f"Undefined variable in expression: {expression}, error: {e}") raise ValueError(f"Undefined variable: {e}") - + except InvalidExpression as e: logger.error(f"Invalid expression syntax: {expression}, error: {e}") raise ValueError(f"Invalid expression syntax: {e}") - + except SyntaxError as e: logger.error(f"Syntax error in expression: {expression}, error: {e}") raise ValueError(f"Syntax error: {e}") - + except Exception as e: logger.error(f"Expression evaluation failed: {expression}, error: {e}") raise ValueError(f"Expression evaluation failed: {e}") - + @staticmethod def evaluate_bool( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + expression: str, + conv_var: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> bool: """ Evaluate a boolean expression (for conditions). @@ -108,7 +108,7 @@ class ExpressionEvaluator: expression, conv_var, node_outputs, system_vars ) return bool(result) - + @staticmethod def validate_variable_names(variables: list[dict]) -> list[str]: """ @@ -121,7 +121,7 @@ class ExpressionEvaluator: list[str]: List of error messages. Empty if all names are valid. """ errors = [] - + for var in variables: var_name = var.get("name", "") @@ -134,16 +134,16 @@ class ExpressionEvaluator: errors.append( f"Variable name '{var_name}' is not a valid Python identifier" ) - + return errors # 便捷函数 def evaluate_expression( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict ) -> Any: """Evaluate an expression (convenience function).""" return ExpressionEvaluator.evaluate( @@ -152,11 +152,11 @@ def evaluate_expression( def evaluate_condition( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None -) -> bool: + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict +) -> Any: """Evaluate a boolean condition expression (convenience function).""" return ExpressionEvaluator.evaluate_bool( expression, conv_var, node_outputs, system_vars diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 6a73efc4..bb1e18bf 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -1,7 +1,8 @@ """ -模板渲染器 +Template Renderer -使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。 +Provides safe template rendering using Jinja2, supporting variable references +and expressions. """ import logging @@ -10,11 +11,15 @@ from typing import Any from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined +from app.core.workflow.engine.variable_pool import LazyVariableDict + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class SafeUndefined(Undefined): - """访问未定义属性不会报错,返回空字符串""" + """Return empty string instead of raising error when accessing undefined variables""" __slots__ = () def _fail_with_undefined_error(self, *args, **kwargs): @@ -26,26 +31,22 @@ class SafeUndefined(Undefined): class TemplateRenderer: - """模板渲染器""" - def __init__(self, strict: bool = True): - """初始化渲染器 - + """Initialize renderer + Args: - strict: 是否使用严格模式(未定义变量会抛出异常) + strict: Whether to enable strict mode (raise error on undefined variables) """ self.strict = strict self.env = Environment( undefined=StrictUndefined if strict else SafeUndefined, - autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML + autoescape=False # Disable auto-escaping since we handle plain text instead of HTML ) @staticmethod def normalize_template(template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + """Normalize template syntax (convert numeric node reference to dict access)""" + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @@ -53,24 +54,24 @@ class TemplateRenderer: def render( self, template: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict | None = None ) -> str: - """渲染模板 - + """Render template + Args: - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出结果 - system_vars: 系统变量 - + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Raises: - ValueError: 模板语法错误或变量未定义 - + ValueError: If template syntax is invalid or variables are undefined + Examples: >>> renderer = TemplateRenderer() >>> renderer.render( @@ -80,122 +81,119 @@ class TemplateRenderer: ... {} ... ) 'Hello World!' - + >>> renderer.render( - ... "分析结果: {{node.analyze.output}}", + ... "Analysis result: {{node.analyze.output}}", ... {}, - ... {"analyze": {"output": "正面情绪"}}, + ... {"analyze": {"output": "positive sentiment"}}, ... {} ... ) - '分析结果: 正面情绪' + 'Analysis result: positive sentiment' """ - # 构建命名空间上下文 + # Build namespace context context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": system_vars, # 系统变量:{{sys.execution_id}} + "conv": conv_vars, # Conversation variables: {{conv.user_name}} + "node": node_outputs, # Node outputs: {{node.node_1.output}} + "sys": system_vars, # System variables: {{sys.execution_id}} } - # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} - # 将所有节点输出添加到顶层上下文 + # Allow direct access to node outputs by node ID: {{llm_qa.output}} if node_outputs: context.update(node_outputs) - # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} - if conv_vars: - context.update(conv_vars) - - context["nodes"] = node_outputs or {} # 旧语法兼容 + # # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} + # if conv_vars: + # context.update(conv_vars) + # + # context["nodes"] = node_outputs or {} # 旧语法兼容 template = self.normalize_template(template) try: tmpl = self.env.from_string(template) return tmpl.render(**context) except TemplateSyntaxError as e: - logger.error(f"模板语法错误: {template}, 错误: {e}") - raise ValueError(f"模板语法错误: {e}") - + logger.error(f"Template syntax error: {template}, error: {e}") + raise ValueError(f"Template syntax error: {e}") except UndefinedError as e: - logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") - raise ValueError(f"未定义的变量: {e}") - + logger.error(f"Undefined variable in template: {template}, error: {e}") + raise ValueError(f"Undefined variable: {e}") except Exception as e: - logger.error(f"模板渲染异常: {template}, 错误: {e}") - raise ValueError(f"模板渲染失败: {e}") + logger.error(f"Template rendering error: {template}, error: {e}") + raise ValueError(f"Template rendering failed: {e}") def validate(self, template: str) -> list[str]: - """验证模板语法 - + """Validate template syntax + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表,如果为空则验证通过 - + List of errors (empty if valid) + Examples: >>> renderer = TemplateRenderer() >>> renderer.validate("Hello {{var.name}}!") [] - - >>> renderer.validate("Hello {{var.name") # 缺少结束标记 - ['模板语法错误: ...'] + + >>> renderer.validate("Hello {{var.name") # Missing closing tag + ['Template syntax error: ...'] """ errors = [] try: self.env.from_string(template) except TemplateSyntaxError as e: - errors.append(f"模板语法错误: {e}") + errors.append(f"Template syntax error: {e}") except Exception as e: - errors.append(f"模板验证失败: {e}") + errors.append(f"Template validation failed: {e}") return errors -# 全局渲染器实例(严格模式) +# Global renderer instances (strict / lenient) _strict_renderer = TemplateRenderer(strict=True) _lenient_renderer = TemplateRenderer(strict=False) def render_template( template: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any], + conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | LazyVariableDict, + system_vars: dict[str, Any] | LazyVariableDict, strict: bool = True ) -> str: - """渲染模板(便捷函数) - + """Render template (convenience function) + Args: - strict: 严格模式 - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出 - system_vars: 系统变量 - + strict: Whether to use strict mode + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Examples: >>> render_template( - ... "请分析: {{var.text}}", - ... {"text": "这是一段文本"}, + ... "Analyze: {{var.text}}", + ... {"text": "This is a text"}, ... {}, ... {} ... ) - '请分析: 这是一段文本' + 'Analyze: This is a text' """ renderer = _strict_renderer if strict else _lenient_renderer return renderer.render(template, conv_vars, node_outputs, system_vars) def validate_template(template: str) -> list[str]: - """验证模板语法(便捷函数) - + """Validate template syntax (convenience function) + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表 + List of errors """ return _strict_renderer.validate(template) From 8ba0a74473ab86e6b8f573f9730104fcbb362f25 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Mar 2026 12:03:48 +0800 Subject: [PATCH 061/120] [changes] Specified element quantity --- .../extraction_orchestrator.py | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 5ef7db0e..66813b8f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -295,6 +295,7 @@ class ExtractionOrchestrator: statement_entity_edges, entity_entity_edges, dialog_data_list, + dedup_details, ) = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -306,6 +307,11 @@ class ExtractionOrchestrator: dialog_data_list, ) + # 步骤 7: 同步用户别名到数据库表(仅正式模式) + if not is_pilot_run: + logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表") + await self._update_end_user_other_name(entity_nodes, dialog_data_list) + logger.info(f"知识提取流水线运行完成({mode_str})") return ( dialogue_nodes, @@ -1492,6 +1498,7 @@ class ExtractionOrchestrator: list[StatementChunkEdge], list[StatementEntityEdge], list[EntityEntityEdge], + list[DialogData], dict ]: """ @@ -1555,6 +1562,8 @@ class ExtractionOrchestrator: statement_chunk_edges, dedup_statement_entity_edges, dedup_entity_entity_edges, + dialog_data_list, + dedup_details, ) final_entity_nodes = dedup_entity_nodes @@ -1562,7 +1571,7 @@ class ExtractionOrchestrator: final_entity_entity_edges = dedup_entity_entity_edges else: # 正式模式:执行完整的两阶段去重 - result_tuple = await dedup_layers_and_merge_and_return( + dedup_result_tuple = await dedup_layers_and_merge_and_return( dialogue_nodes, chunk_nodes, statement_nodes, @@ -1578,19 +1587,31 @@ class ExtractionOrchestrator: # 解包返回值 ( - _, - _, - _, + dialogue_nodes, + chunk_nodes, + statement_nodes, final_entity_nodes, - _, + statement_chunk_edges, final_statement_entity_edges, final_entity_entity_edges, dedup_details, - ) = result_tuple + ) = dedup_result_tuple # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) + result_tuple = ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + final_entity_nodes, + statement_chunk_edges, + final_statement_entity_edges, + final_entity_entity_edges, + dialog_data_list, + dedup_details, + ) + logger.info( f"去重后: {len(final_entity_nodes)} 个实体节点, " f"{len(final_statement_entity_edges)} 条陈述句-实体边, " From f30260939a1750eed53d7a8090168e1c88f82488 Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 12:20:03 +0800 Subject: [PATCH 062/120] feat: Add feature_billing and feature_user_management fields to tenant model --- api/app/controllers/user_controller.py | 69 +++++++++++++++++++++++++ api/app/repositories/user_repository.py | 30 +++++++---- api/app/schemas/user_schema.py | 3 +- api/app/services/tenant_service.py | 12 +++-- 4 files changed, 98 insertions(+), 16 deletions(-) diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index 16213690..20e2b974 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -111,6 +111,18 @@ def get_current_user_info( break api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") + + # 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回全部权限 + if current_user.external_source: + from premium.sso.models import SSOSource + source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() + if source and source.permissions: + result_schema.permissions = source.permissions + else: + result_schema.permissions = [] + else: + result_schema.permissions = ["pricing", "user"] + return success(data=result_schema, msg=t("users.info.get_success")) @@ -135,6 +147,63 @@ def get_tenant_superusers( return success(data=superusers_schema, msg=t("users.list.superusers_success")) +@router.get("/tenant/users", response_model=ApiResponse) +def get_tenant_users( + page: int = 1, + size: int = 20, + is_active: bool = None, + is_superuser: bool = None, + search: str = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) +): + """获取当前用户所在租户的用户列表(普通用户可访问)""" + api_logger.info(f"获取租户用户列表请求: tenant_id={current_user.tenant_id}, 操作者: {current_user.username}") + + if not current_user.tenant_id: + raise BusinessException("用户没有租户信息", code=BizCode.TENANT_NOT_FOUND) + + from app.services.tenant_service import TenantService + tenant_service = TenantService(db) + + skip = (page - 1) * size + users = tenant_service.get_tenant_users( + tenant_id=current_user.tenant_id, + skip=skip, + limit=size, + is_active=is_active, + is_superuser=is_superuser, + search=search + ) + total = tenant_service.count_tenant_users( + tenant_id=current_user.tenant_id, + is_active=is_active, + is_superuser=is_superuser, + search=search + ) + + users_schema = [user_schema.User.model_validate(u) for u in users] + for u_schema in users_schema: + user = users[[s.id for s in users_schema].index(u_schema.id)] + if user.external_source: + from premium.sso.models import SSOSource + source = db.query(SSOSource).filter(SSOSource.source_code == user.external_source).first() + u_schema.permissions = source.permissions if source and source.permissions else [] + else: + u_schema.permissions = ["pricing", "user"] + + return success( + data={ + "users": users_schema, + "total": total, + "page": page, + "size": size, + }, + msg=t("users.list.get_success") + ) + + @router.get("/{user_id}", response_model=ApiResponse) def get_user_info_by_id( diff --git a/api/app/repositories/user_repository.py b/api/app/repositories/user_repository.py index 3f8919aa..af4449e5 100644 --- a/api/app/repositories/user_repository.py +++ b/api/app/repositories/user_repository.py @@ -158,22 +158,26 @@ class UserRepository: raise def get_users_by_tenant( - self, - tenant_id: uuid.UUID, - skip: int = 0, + self, + tenant_id: uuid.UUID, + skip: int = 0, limit: int = 100, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> List[User]: """获取租户下的用户列表""" db_logger.debug(f"查询租户用户: tenant_id={tenant_id}") - + try: query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id) - + if is_active is not None: query = query.filter(User.is_active == is_active) - + + if is_superuser is not None: + query = query.filter(User.is_superuser == is_superuser) + if search: query = query.filter( or_( @@ -181,7 +185,7 @@ class UserRepository: User.email.ilike(f"%{search}%") ) ) - + users = query.offset(skip).limit(limit).all() db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}") return users @@ -190,18 +194,22 @@ class UserRepository: raise def count_users_by_tenant( - self, + self, tenant_id: uuid.UUID, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> int: """统计租户下的用户数量""" try: query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) - + if is_active is not None: query = query.filter(User.is_active == is_active) - + + if is_superuser is not None: + query = query.filter(User.is_superuser == is_superuser) + if search: query = query.filter( or_( @@ -209,7 +217,7 @@ class UserRepository: User.email.ilike(f"%{search}%") ) ) - + return query.scalar() except Exception as e: db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}") diff --git a/api/app/schemas/user_schema.py b/api/app/schemas/user_schema.py index 6b880696..f307a5a3 100644 --- a/api/app/schemas/user_schema.py +++ b/api/app/schemas/user_schema.py @@ -1,6 +1,6 @@ from dataclasses import field from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict -from typing import Optional +from typing import Optional, List import datetime import uuid @@ -85,6 +85,7 @@ class User(UserBase): current_workspace_name: Optional[str] = None role: Optional[WorkspaceRole] = None preferred_language: Optional[str] = "zh" # 用户语言偏好 + permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制 # 将 datetime 转换为毫秒时间戳 @validator("created_at", pre=True) diff --git a/api/app/services/tenant_service.py b/api/app/services/tenant_service.py index 066edf57..b9c5800d 100644 --- a/api/app/services/tenant_service.py +++ b/api/app/services/tenant_service.py @@ -142,11 +142,12 @@ class TenantService: # 租户用户管理 def get_tenant_users( - self, - tenant_id: uuid.UUID, - skip: int = 0, + self, + tenant_id: uuid.UUID, + skip: int = 0, limit: int = 100, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> List[UserModel]: """获取租户下的用户列表""" @@ -155,19 +156,22 @@ class TenantService: skip=skip, limit=limit, is_active=is_active, + is_superuser=is_superuser, search=search ) def count_tenant_users( - self, + self, tenant_id: uuid.UUID, is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, search: Optional[str] = None ) -> int: """统计租户下的用户数量""" return self.user_repo.count_users_by_tenant( tenant_id=tenant_id, is_active=is_active, + is_superuser=is_superuser, search=search ) From bca43fcc75e41efa68f4fd997ee8acb0587c334d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 27 Mar 2026 12:02:36 +0800 Subject: [PATCH 063/120] perf(workflow): expose extract_document_text as instance method, optimize knowledge base parallel search - Change extract_document_text from private to instance method in multimodal service for external access - Optimize knowledge base search logic to improve parallel retrieval performance --- .../workflow/nodes/document_extractor/node.py | 2 +- api/app/core/workflow/nodes/knowledge/node.py | 175 ++++++++++++------ .../core/workflow/utils/template_renderer.py | 2 +- api/app/services/multimodal_service.py | 6 +- 4 files changed, 120 insertions(+), 65 deletions(-) diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py index 40641f3c..bd828760 100644 --- a/api/app/core/workflow/nodes/document_extractor/node.py +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode): # Reuse cached bytes if already fetched if f.get_content(): file_input.set_content(f.get_content()) - text = await svc._extract_document_text(file_input) + text = await svc.extract_document_text(file_input) chunks.append(text) except Exception as e: logger.error( diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 92699cb4..d0b6d098 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -1,19 +1,23 @@ +import asyncio import logging import uuid from typing import Any +from langchain_core.documents import Document + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector +from app.core.rag.models.chunk import DocumentChunk +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_read -from app.models import knowledge_model, knowledgeshare_model, ModelType -from app.repositories import knowledge_repository, knowledgeshare_repository +from app.models import knowledge_model, ModelType +from app.repositories import knowledge_repository from app.schemas.chunk_schema import RetrieveType from app.services.model_service import ModelConfigService @@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: KnowledgeRetrievalNodeConfig | None = None - self.vector_service: ElasticSearchVector | None = None def _output_types(self) -> dict[str, VariableType]: return { @@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode): unique.append(doc) return unique - def _get_existing_kb_ids(self, db, kb_ids): + def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: """ - Resolve all accessible and valid knowledge base IDs for retrieval. - - This includes: - - Private knowledge bases owned by the user - - Shared knowledge bases - - Source knowledge bases mapped via knowledge sharing relationships - + Reorder the list of document blocks and return the top_k results most relevant to the query Args: - db: Database session. - kb_ids (list[UUID]): Knowledge base IDs from node configuration. + query: query string + docs: List of document chunk to be rearranged + top_k: The number of top-level documents returned Returns: - list[UUID]: Final list of valid knowledge base IDs. + Rearranged document chunk list (sorted in descending order of relevance) + + Raises: + ValueError: If the input document list is empty or top_k is invalid """ - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private) - - existing_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) - - share_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - if share_ids: - filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids) + reranker = self.get_reranker_model() + # parameter validation + if not docs: + raise ValueError("retrieval chunks be empty") + if top_k <= 0: + raise ValueError("top_k must be a positive integer") + try: + # Convert to LangChain Document object + documents = [ + Document( + page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute + metadata=doc.metadata or {} # Deal with possible None metadata + ) + for doc in docs ] - items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters + + # Perform reordering (compress_documents will automatically handle relevance scores and indexing) + reranked_docs = list(reranker.compress_documents(documents, query)) + + # Sort in descending order based on relevance score + reranked_docs.sort( + key=lambda x: x.metadata.get("relevance_score", 0), + reverse=True ) - existing_ids.extend(items) - return existing_ids + # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"] + result = [] + for item in reranked_docs[:top_k]: + for doc in docs: + if doc.page_content == item.page_content: + doc.metadata["score"] = item.metadata["relevance_score"] + result.append(doc) + return result + except Exception as e: + raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e def get_reranker_model(self) -> RedBearRerank: """ @@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker - def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): + async def knowledge_retrieval(self, db, query, db_knowledge, kb_config): + rs = [] if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) + tasks = [] for child in children: if not (child and child.chunk_num > 0 and child.status == 1): continue - kb_config.kb_id = child.id - self.knowledge_retrieval(db, query, rs, child, kb_config) - return - self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + child_kb_config = kb_config.model_copy() + child_kb_config.kb_id = child.id + tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) + return rs + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) indices = f"Vector_index_{kb_config.kb_id}_Node".lower() match kb_config.retrieve_type: case RetrieveType.PARTICIPLE: - rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + ) case RetrieveType.SEMANTIC: - rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + ) case RetrieveType.HYBRID: - rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight) - rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold) + rs1_task = asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + rs2_task = asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + rs1, rs2 = await asyncio.gather(rs1_task, rs2_task) # Deduplicate hybrid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) if not unique_rs: - return + return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + rs.extend( + await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": unique_rs, "top_k": kb_config.top_k} + ) + ) else: rs.extend(sorted( unique_rs, @@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode): )[:kb_config.top_k]) case _: raise RuntimeError("Unknown retrieval type") + return rs async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ @@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode): knowledge_bases = self.typed_config.knowledge_bases rs = [] + tasks = [] for kb_config in knowledge_bases: db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) if not db_knowledge: raise RuntimeError("The knowledge base does not exist or access is denied.") - self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) + tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) if not rs: return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + final_rs = await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k} + ) else: final_rs = sorted( rs, diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index bb1e18bf..2c2d0f67 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -158,7 +158,7 @@ _lenient_renderer = TemplateRenderer(strict=False) def render_template( template: str, conv_vars: dict[str, Any] | LazyVariableDict, - node_outputs: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], system_vars: dict[str, Any] | LazyVariableDict, strict: bool = True ) -> str: diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 4cf3d89d..120cccb7 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -438,13 +438,13 @@ class MultimodalService: if file.transfer_method == TransferMethod.REMOTE_URL: return True, { "type": "text", - "text": f"\n{await self._extract_document_text(file)}\n" + "text": f"\n{await self.extract_document_text(file)}\n" } else: # 本地文件,提取文本内容 server_url = settings.FILE_LOCAL_SERVER_URL file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" - text = await self._extract_document_text(file) + text = await self.extract_document_text(file) file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file.upload_file_id ).first() @@ -542,7 +542,7 @@ class MultimodalService: server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" - async def _extract_document_text(self, file: FileInput) -> str: + async def extract_document_text(self, file: FileInput) -> str: """ 提取文档文本内容 From 9730c5ce0f87f66e2f3147a8eeb7c836d585ba81 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Mar 2026 12:24:52 +0800 Subject: [PATCH 064/120] [changes] Construct the final return structure directly. --- .../extraction_orchestrator.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 66813b8f..f6a143cd 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -182,7 +182,7 @@ class ExtractionOrchestrator: list[StatementEntityEdge], list[EntityEntityEdge], list[PerceptualEdge], - dict + list[DialogData] ]: """ 运行完整的知识提取流水线(优化版:并行执行) @@ -1571,7 +1571,16 @@ class ExtractionOrchestrator: final_entity_entity_edges = dedup_entity_entity_edges else: # 正式模式:执行完整的两阶段去重 - dedup_result_tuple = await dedup_layers_and_merge_and_return( + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + final_entity_nodes, + statement_chunk_edges, + final_statement_entity_edges, + final_entity_entity_edges, + dedup_details, + ) = await dedup_layers_and_merge_and_return( dialogue_nodes, chunk_nodes, statement_nodes, @@ -1585,18 +1594,6 @@ class ExtractionOrchestrator: llm_client=self.llm_client, ) - # 解包返回值 - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - final_entity_nodes, - statement_chunk_edges, - final_statement_entity_edges, - final_entity_entity_edges, - dedup_details, - ) = dedup_result_tuple - # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) From 14838dc06400062691dcf9b61029343815fe4926 Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 13:58:31 +0800 Subject: [PATCH 065/120] feat: Update user controller --- api/app/controllers/user_controller.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index 20e2b974..e67a0b76 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -112,7 +112,7 @@ def get_current_user_info( api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") - # 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回全部权限 + # 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限 if current_user.external_source: from premium.sso.models import SSOSource source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first() @@ -121,7 +121,8 @@ def get_current_user_info( else: result_schema.permissions = [] else: - result_schema.permissions = ["pricing", "user"] + from premium.sso.base import SSOPermission + result_schema.permissions = [SSOPermission.ALL.value] return success(data=result_schema, msg=t("users.info.get_success")) @@ -191,7 +192,8 @@ def get_tenant_users( source = db.query(SSOSource).filter(SSOSource.source_code == user.external_source).first() u_schema.permissions = source.permissions if source and source.permissions else [] else: - u_schema.permissions = ["pricing", "user"] + from premium.sso.base import SSOPermission + u_schema.permissions = [SSOPermission.ALL.value] return success( data={ From ee6b8ffa628f05cdfeb4136751607c0680f99fef Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 14:07:52 +0800 Subject: [PATCH 066/120] feat: Update user controller. --- api/app/services/tenant_service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/app/services/tenant_service.py b/api/app/services/tenant_service.py index b9c5800d..2ff7ff51 100644 --- a/api/app/services/tenant_service.py +++ b/api/app/services/tenant_service.py @@ -142,9 +142,9 @@ class TenantService: # 租户用户管理 def get_tenant_users( - self, - tenant_id: uuid.UUID, - skip: int = 0, + self, + tenant_id: uuid.UUID, + skip: int = 0, limit: int = 100, is_active: Optional[bool] = None, is_superuser: Optional[bool] = None, @@ -161,7 +161,7 @@ class TenantService: ) def count_tenant_users( - self, + self, tenant_id: uuid.UUID, is_active: Optional[bool] = None, is_superuser: Optional[bool] = None, From d0ca5c8b276ae96535daf46b8a294cbb82c0410c Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 14:17:22 +0800 Subject: [PATCH 067/120] feat: Update user controller --- api/app/controllers/user_controller.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index e67a0b76..3626f169 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -121,8 +121,7 @@ def get_current_user_info( else: result_schema.permissions = [] else: - from premium.sso.base import SSOPermission - result_schema.permissions = [SSOPermission.ALL.value] + result_schema.permissions = ["all"] return success(data=result_schema, msg=t("users.info.get_success")) @@ -192,8 +191,7 @@ def get_tenant_users( source = db.query(SSOSource).filter(SSOSource.source_code == user.external_source).first() u_schema.permissions = source.permissions if source and source.permissions else [] else: - from premium.sso.base import SSOPermission - u_schema.permissions = [SSOPermission.ALL.value] + u_schema.permissions = ["all"] return success( data={ From 539999131caa6e5a97ac70421318509f77f2a459 Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 14:26:46 +0800 Subject: [PATCH 068/120] feat: Update user controller --- api/app/controllers/user_controller.py | 58 ----------------------- api/app/repositories/tenant_repository.py | 24 +--------- api/app/services/tenant_service.py | 37 +-------------- 3 files changed, 2 insertions(+), 117 deletions(-) diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index 3626f169..cc16a6b4 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -147,64 +147,6 @@ def get_tenant_superusers( return success(data=superusers_schema, msg=t("users.list.superusers_success")) -@router.get("/tenant/users", response_model=ApiResponse) -def get_tenant_users( - page: int = 1, - size: int = 20, - is_active: bool = None, - is_superuser: bool = None, - search: str = None, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), - t: Callable = Depends(get_translator) -): - """获取当前用户所在租户的用户列表(普通用户可访问)""" - api_logger.info(f"获取租户用户列表请求: tenant_id={current_user.tenant_id}, 操作者: {current_user.username}") - - if not current_user.tenant_id: - raise BusinessException("用户没有租户信息", code=BizCode.TENANT_NOT_FOUND) - - from app.services.tenant_service import TenantService - tenant_service = TenantService(db) - - skip = (page - 1) * size - users = tenant_service.get_tenant_users( - tenant_id=current_user.tenant_id, - skip=skip, - limit=size, - is_active=is_active, - is_superuser=is_superuser, - search=search - ) - total = tenant_service.count_tenant_users( - tenant_id=current_user.tenant_id, - is_active=is_active, - is_superuser=is_superuser, - search=search - ) - - users_schema = [user_schema.User.model_validate(u) for u in users] - for u_schema in users_schema: - user = users[[s.id for s in users_schema].index(u_schema.id)] - if user.external_source: - from premium.sso.models import SSOSource - source = db.query(SSOSource).filter(SSOSource.source_code == user.external_source).first() - u_schema.permissions = source.permissions if source and source.permissions else [] - else: - u_schema.permissions = ["all"] - - return success( - data={ - "users": users_schema, - "total": total, - "page": page, - "size": size, - }, - msg=t("users.list.get_success") - ) - - - @router.get("/{user_id}", response_model=ApiResponse) def get_user_info_by_id( user_id: uuid.UUID, diff --git a/api/app/repositories/tenant_repository.py b/api/app/repositories/tenant_repository.py index 2934dda3..462c75e5 100644 --- a/api/app/repositories/tenant_repository.py +++ b/api/app/repositories/tenant_repository.py @@ -100,15 +100,6 @@ class TenantRepository: db_tenant.is_active = False return True - def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]: - """获取租户下的所有用户""" - query = self.db.query(User).filter(User.tenant_id == tenant_id) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - - return query.all() - def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]: """获取用户所属的租户""" user = self.db.query(User).filter(User.id == user_id).first() @@ -130,15 +121,6 @@ class TenantRepository: user.tenant_id = tenant_id self.db.flush() - return True - - def count_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> int: - """统计租户下的用户数量""" - query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - return query.scalar() @@ -161,8 +143,4 @@ def get_tenants(db: Session, skip: int = 0, limit: int = 100) -> List[Tenants]: def get_user_tenant(db: Session, user_id: uuid.UUID) -> Optional[Tenants]: """获取用户所属的租户""" - return TenantRepository(db).get_user_tenant(user_id) - -def get_tenant_users(db: Session, tenant_id: uuid.UUID) -> List[User]: - """获取租户下的所有用户""" - return TenantRepository(db).get_tenant_users(tenant_id) \ No newline at end of file + return TenantRepository(db).get_user_tenant(user_id) \ No newline at end of file diff --git a/api/app/services/tenant_service.py b/api/app/services/tenant_service.py index 2ff7ff51..369327ba 100644 --- a/api/app/services/tenant_service.py +++ b/api/app/services/tenant_service.py @@ -138,42 +138,7 @@ class TenantService: except Exception as e: business_logger.error(f"删除租户失败: {str(e)}") - raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR) - - # 租户用户管理 - def get_tenant_users( - self, - tenant_id: uuid.UUID, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None, - is_superuser: Optional[bool] = None, - search: Optional[str] = None - ) -> List[UserModel]: - """获取租户下的用户列表""" - return self.user_repo.get_users_by_tenant( - tenant_id=tenant_id, - skip=skip, - limit=limit, - is_active=is_active, - is_superuser=is_superuser, - search=search - ) - - def count_tenant_users( - self, - tenant_id: uuid.UUID, - is_active: Optional[bool] = None, - is_superuser: Optional[bool] = None, - search: Optional[str] = None - ) -> int: - """统计租户下的用户数量""" - return self.user_repo.count_users_by_tenant( - tenant_id=tenant_id, - is_active=is_active, - is_superuser=is_superuser, - search=search - ) + raise BusinessException(f"删除租户失败:{str(e)}", code=BizCode.DB_ERROR) def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: """将用户分配给租户""" From 2597a1f5321220411b0a5d31b2791970f7a4a278 Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 14:36:19 +0800 Subject: [PATCH 069/120] feat: Update user controller --- api/app/repositories/tenant_repository.py | 24 +++++++++++++++- api/app/services/tenant_service.py | 35 +++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/api/app/repositories/tenant_repository.py b/api/app/repositories/tenant_repository.py index 462c75e5..2934dda3 100644 --- a/api/app/repositories/tenant_repository.py +++ b/api/app/repositories/tenant_repository.py @@ -100,6 +100,15 @@ class TenantRepository: db_tenant.is_active = False return True + def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]: + """获取租户下的所有用户""" + query = self.db.query(User).filter(User.tenant_id == tenant_id) + + if is_active is not None: + query = query.filter(User.is_active == is_active) + + return query.all() + def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]: """获取用户所属的租户""" user = self.db.query(User).filter(User.id == user_id).first() @@ -121,6 +130,15 @@ class TenantRepository: user.tenant_id = tenant_id self.db.flush() + return True + + def count_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> int: + """统计租户下的用户数量""" + query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) + + if is_active is not None: + query = query.filter(User.is_active == is_active) + return query.scalar() @@ -143,4 +161,8 @@ def get_tenants(db: Session, skip: int = 0, limit: int = 100) -> List[Tenants]: def get_user_tenant(db: Session, user_id: uuid.UUID) -> Optional[Tenants]: """获取用户所属的租户""" - return TenantRepository(db).get_user_tenant(user_id) \ No newline at end of file + return TenantRepository(db).get_user_tenant(user_id) + +def get_tenant_users(db: Session, tenant_id: uuid.UUID) -> List[User]: + """获取租户下的所有用户""" + return TenantRepository(db).get_tenant_users(tenant_id) \ No newline at end of file diff --git a/api/app/services/tenant_service.py b/api/app/services/tenant_service.py index 369327ba..36205503 100644 --- a/api/app/services/tenant_service.py +++ b/api/app/services/tenant_service.py @@ -140,6 +140,41 @@ class TenantService: business_logger.error(f"删除租户失败: {str(e)}") raise BusinessException(f"删除租户失败:{str(e)}", code=BizCode.DB_ERROR) + # 租户用户管理 + def get_tenant_users( + self, + tenant_id: uuid.UUID, + skip: int = 0, + limit: int = 100, + is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, + search: Optional[str] = None + ) -> List[UserModel]: + """获取租户下的用户列表""" + return self.user_repo.get_users_by_tenant( + tenant_id=tenant_id, + skip=skip, + limit=limit, + is_active=is_active, + is_superuser=is_superuser, + search=search + ) + + def count_tenant_users( + self, + tenant_id: uuid.UUID, + is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, + search: Optional[str] = None + ) -> int: + """统计租户下的用户数量""" + return self.user_repo.count_users_by_tenant( + tenant_id=tenant_id, + is_active=is_active, + is_superuser=is_superuser, + search=search + ) + def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: """将用户分配给租户""" # 检查租户是否存在 From 7fbf3e8873817f6120e573bb128cc34216943981 Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 14:48:25 +0800 Subject: [PATCH 070/120] feat: Update user controller --- api/app/models/user_model.py | 5 ++++- api/app/schemas/user_schema.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py index 81319789..c0b17d14 100644 --- a/api/app/models/user_model.py +++ b/api/app/models/user_model.py @@ -19,9 +19,12 @@ class User(Base): last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空 # SSO 外部关联字段 - external_id = Column(String(100), nullable=True) # 外部用户ID + external_id = Column(String(100), nullable=True) # 外部用户 ID external_source = Column(String(50), nullable=True) # 来源系统 + # 用户联系方式 + phone = Column(String(50), nullable=True) # 用户电话 + # 用户语言偏好 preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文 diff --git a/api/app/schemas/user_schema.py b/api/app/schemas/user_schema.py index f307a5a3..aa9ac256 100644 --- a/api/app/schemas/user_schema.py +++ b/api/app/schemas/user_schema.py @@ -20,6 +20,7 @@ class UserCreate(UserBase): class UserUpdate(BaseModel): username: Optional[str] = None email: Optional[EmailStr] = None + phone: Optional[str] = None is_active: Optional[bool] = None is_superuser: Optional[bool] = None @@ -85,6 +86,7 @@ class User(UserBase): current_workspace_name: Optional[str] = None role: Optional[WorkspaceRole] = None preferred_language: Optional[str] = "zh" # 用户语言偏好 + phone: Optional[str] = None # 用户电话 permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制 # 将 datetime 转换为毫秒时间戳 From 758be0087f04e022bd1db6ffa9311708d30a5b82 Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 27 Mar 2026 15:13:17 +0800 Subject: [PATCH 071/120] [add] migration script --- .../versions/4e89970f9e7c_202603271515.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 api/migrations/versions/4e89970f9e7c_202603271515.py diff --git a/api/migrations/versions/4e89970f9e7c_202603271515.py b/api/migrations/versions/4e89970f9e7c_202603271515.py new file mode 100644 index 00000000..f37c4b27 --- /dev/null +++ b/api/migrations/versions/4e89970f9e7c_202603271515.py @@ -0,0 +1,30 @@ +"""202603271515 + +Revision ID: 4e89970f9e7c +Revises: 6b8a461148ff +Create Date: 2026-03-27 15:12:27.518344 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4e89970f9e7c' +down_revision: Union[str, None] = '6b8a461148ff' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('users', 'phone') + # ### end Alembic commands ### From e659ca9fa295d50a0ccc54eade5ae24296756ad6 Mon Sep 17 00:00:00 2001 From: wxy Date: Fri, 27 Mar 2026 15:48:21 +0800 Subject: [PATCH 072/120] refactor(app): merge API Key search into search parameter --- api/app/controllers/app_controller.py | 37 +++++++++++++++------------ 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 3ba9c3a9..352e0f0c 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -57,7 +57,6 @@ def list_apps( page: int = 1, pagesize: int = 10, ids: Optional[str] = None, - api_key: Optional[str] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -66,7 +65,7 @@ def list_apps( - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页 - - 当提供 api_key 参数时,查找该 API Key 关联的应用 + - search 参数支持:应用名称模糊搜索、API Key 精确搜索 """ from sqlalchemy import select as sa_select from app.models.api_key_model import ApiKey @@ -74,23 +73,29 @@ def list_apps( workspace_id = current_user.current_workspace_id service = app_service.AppService(db) - # 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程 - if api_key: - matched_id = db.execute( - sa_select(ApiKey.resource_id).where( - ApiKey.workspace_id == workspace_id, - ApiKey.api_key == api_key, - ApiKey.resource_id.isnot(None), - ) - ).scalar_one_or_none() - ids = str(matched_id) if matched_id else "" + # 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索 + if search: + search = search.strip() + # 尝试作为 API Key 精确匹配(API Key 通常较长) + if len(search) >= 10: + matched_id = db.execute( + sa_select(ApiKey.resource_id).where( + ApiKey.workspace_id == workspace_id, + ApiKey.api_key == search, + ApiKey.resource_id.isnot(None), + ) + ).scalar_one_or_none() + if matched_id: + # 找到 API Key,直接返回关联的应用 + ids = str(matched_id) - # 当 ids 存在且不为 None 时,根据 ids 获取应用 + # 当 ids 存在时,根据 ids 获取应用(不分页) if ids is not None: app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] - items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) - items = [service._convert_to_schema(app, workspace_id) for app in items_orm] - return success(data=items) + if app_ids: + items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) + items = [service._convert_to_schema(app, workspace_id) for app in items_orm] + return success(data=items) # 正常分页查询 items_orm, total = app_service.list_apps( From bd70a8b8123e5c04fb091bfea6153d9390d9d683 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Fri, 27 Mar 2026 16:25:46 +0800 Subject: [PATCH 073/120] fix(app): localize validation messages and enhance error context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace English validation messages with Chinese localized strings - Update "model config" to "模型配置" - Update "memory config" to "记忆配置" - Enhance error message with detailed context about missing configurations - Add BizCode.CONFIG_MISSING error code for better error handling - Include missing_params in error context for debugging and client-side handling- --- api/app/services/app_service.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 4dcabff8..736049e5 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1753,12 +1753,16 @@ class AppService: miss_params = [] if agent_cfg.default_model_config_id is None: - miss_params.append("model config") + miss_params.append("模型配置") if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"): - miss_params.append("memory config") + miss_params.append("记忆配置") if miss_params: - raise BusinessException(f"{', '.join(miss_params)} is required") + raise BusinessException( + f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。", + BizCode.CONFIG_MISSING, + context={"missing_params": miss_params}, + ) config = { "system_prompt": agent_cfg.system_prompt, From 46fa99a8b85b4abeb45b22336442a6a3df715d8b Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Fri, 27 Mar 2026 16:27:09 +0800 Subject: [PATCH 074/120] fix(app): 1.Handling of large file upload issues; 2. Handling of abnormal display of conversation titles when the opening remarks function is enabled --- api/app/core/storage/oss.py | 101 +++++++++++++++++++---- api/app/models/conversation_model.py | 6 ++ api/app/schemas/app_schema.py | 2 +- api/app/services/app_chat_service.py | 8 +- api/app/services/app_service.py | 2 - api/app/services/conversation_service.py | 8 +- api/app/services/draft_run_service.py | 29 ++++--- 7 files changed, 121 insertions(+), 35 deletions(-) diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py index 1db86fef..c6c6ec48 100644 --- a/api/app/core/storage/oss.py +++ b/api/app/core/storage/oss.py @@ -44,6 +44,8 @@ class OSSStorage(StorageBackend): access_key_id: str, access_key_secret: str, bucket_name: str, + connect_timeout: int = 30, + multipart_threshold: int = 10 * 1024 * 1024, # 10MB ): """ Initialize the OSSStorage backend. @@ -53,6 +55,8 @@ class OSSStorage(StorageBackend): access_key_id: The Aliyun access key ID. access_key_secret: The Aliyun access key secret. bucket_name: The name of the OSS bucket. + connect_timeout: Connection timeout in seconds (default: 30). + multipart_threshold: File size threshold for multipart upload (default: 10MB). Raises: StorageConfigError: If any required configuration is missing. @@ -69,10 +73,17 @@ class OSSStorage(StorageBackend): self.endpoint = endpoint self.bucket_name = bucket_name + self.multipart_threshold = multipart_threshold try: auth = oss2.Auth(access_key_id, access_key_secret) - self.bucket = oss2.Bucket(auth, endpoint, bucket_name) + # 设置超时和重试 + self.bucket = oss2.Bucket( + auth, + endpoint, + bucket_name, + connect_timeout=connect_timeout + ) logger.info( f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}" ) @@ -108,21 +119,38 @@ class OSSStorage(StorageBackend): if content_type: headers["Content-Type"] = content_type - self.bucket.put_object(file_key, content, headers=headers if headers else None) + # 大文件使用分片上传 + if len(content) > self.multipart_threshold: + logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)") + upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id + parts = [] + part_size = 5 * 1024 * 1024 # 5MB per part + part_num = 1 + + for offset in range(0, len(content), part_size): + chunk = content[offset:offset + part_size] + result = self.bucket.upload_part(file_key, upload_id, part_num, chunk) + parts.append(oss2.models.PartInfo(part_num, result.etag)) + part_num += 1 + + self.bucket.complete_multipart_upload(file_key, upload_id, parts) + else: + self.bucket.put_object(file_key, content, headers=headers if headers else None) + logger.info(f"File uploaded to OSS successfully: {file_key}") return file_key except OssError as e: logger.error(f"OSS error uploading file {file_key}: {e}") raise StorageUploadError( - message=f"Failed to upload file to OSS: {e.message}", + message=f"Failed to upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to upload file to OSS {file_key}: {e}") raise StorageUploadError( - message=f"Failed to upload file to OSS: {e}", + message=f"Failed to upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) @@ -135,28 +163,73 @@ class OSSStorage(StorageBackend): ) -> int: """Upload from async stream to OSS. Returns total bytes written.""" buf = io.BytesIO() + headers = {"Content-Type": content_type} if content_type else None + upload_id = None + try: + # 收集流数据 + total_size = 0 async for chunk in stream: + if not chunk: + continue buf.write(chunk) + total_size += len(chunk) + content = buf.getvalue() - headers = {"Content-Type": content_type} if content_type else None - self.bucket.put_object(file_key, content, headers=headers) - logger.info(f"File stream uploaded to OSS successfully: {file_key}") - return len(content) + + if not content: + raise StorageUploadError( + message="Empty stream content", + file_key=file_key, + ) + + # 大文件使用分片上传 + if len(content) > self.multipart_threshold: + logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)") + upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id + parts = [] + part_size = 5 * 1024 * 1024 # 5MB + part_num = 1 + + for offset in range(0, len(content), part_size): + chunk = content[offset:offset + part_size] + result = self.bucket.upload_part(file_key, upload_id, part_num, chunk) + parts.append(oss2.models.PartInfo(part_num, result.etag)) + part_num += 1 + + self.bucket.complete_multipart_upload(file_key, upload_id, parts) + else: + self.bucket.put_object(file_key, content, headers=headers) + + logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)") + return total_size + except OssError as e: + if upload_id: + try: + self.bucket.abort_multipart_upload(file_key, upload_id) + except: + pass logger.error(f"OSS error stream uploading file {file_key}: {e}") raise StorageUploadError( - message=f"Failed to stream upload file to OSS: {e.message}", + message=f"Failed to stream upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: + if upload_id: + try: + self.bucket.abort_multipart_upload(file_key, upload_id) + except: + pass logger.error(f"Failed to stream upload file to OSS {file_key}: {e}") raise StorageUploadError( - message=f"Failed to stream upload file to OSS: {e}", + message=f"Failed to stream upload file to OSS: {str(e)}", file_key=file_key, cause=e, ) + finally: + buf.close() async def download(self, file_key: str) -> bytes: """ @@ -182,14 +255,14 @@ class OSSStorage(StorageBackend): except OssError as e: logger.error(f"OSS error downloading file {file_key}: {e}") raise StorageDownloadError( - message=f"Failed to download file from OSS: {e.message}", + message=f"Failed to download file from OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to download file from OSS {file_key}: {e}") raise StorageDownloadError( - message=f"Failed to download file from OSS: {e}", + message=f"Failed to download file from OSS: {str(e)}", file_key=file_key, cause=e, ) @@ -215,14 +288,14 @@ class OSSStorage(StorageBackend): except OssError as e: logger.error(f"OSS error deleting file {file_key}: {e}") raise StorageDeleteError( - message=f"Failed to delete file from OSS: {e.message}", + message=f"Failed to delete file from OSS: {str(e)}", file_key=file_key, cause=e, ) except Exception as e: logger.error(f"Failed to delete file from OSS {file_key}: {e}") raise StorageDeleteError( - message=f"Failed to delete file from OSS: {e}", + message=f"Failed to delete file from OSS: {str(e)}", file_key=file_key, cause=e, ) diff --git a/api/app/models/conversation_model.py b/api/app/models/conversation_model.py index 4011247f..4ae9034d 100644 --- a/api/app/models/conversation_model.py +++ b/api/app/models/conversation_model.py @@ -57,6 +57,12 @@ class Conversation(Base): workspace = relationship("Workspace") messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") + @property + def is_first_user_message(self): + """判断当前是否是用户的第一条消息(无视开场白)""" + user_message_count = sum(1 for msg in self.messages if msg.role == "user") + return user_message_count == 1 + class ConversationDetail(Base): __tablename__ = "conversation_details" diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index e34945eb..f1e9132f 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel): # 记忆配置 memory: MemoryConfig = Field( - default_factory=lambda: MemoryConfig(enabled=True), + default_factory=lambda: MemoryConfig(enabled=False), description="对话历史记忆配置" ) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 90474428..bdccd787 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -140,13 +140,13 @@ class AppChatService: # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 is_new_conversation = len(history) == 0 if is_new_conversation: - opening = self.agent_service._get_opening_statement(features_config, True, variables) + opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables) if opening: self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=opening, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) # 重新加载历史(包含刚写入的开场白) history = await self.conversation_service.get_conversation_history( @@ -367,13 +367,13 @@ class AppChatService: # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 is_new_conversation = len(history) == 0 if is_new_conversation: - opening = self.agent_service._get_opening_statement(features_config, True, variables) + opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables) if opening: self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=opening, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) # 重新加载历史(包含刚写入的开场白) history = await self.conversation_service.get_conversation_history( diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 4dcabff8..4d2aa1c5 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1084,7 +1084,6 @@ class AppService: if not exists: cleaned["memory_config_id"] = None cleaned.pop("memory_content", None) - cleaned["enabled"] = False return cleaned exists = self.db.query( @@ -1096,7 +1095,6 @@ class AppService: if not exists: cleaned["memory_config_id"] = None cleaned.pop("memory_content", None) - cleaned["enabled"] = False return cleaned diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 014d96b7..ecf316d9 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -214,14 +214,14 @@ class ConversationService: conversation.message_count += 1 - if conversation.message_count == 1 and role == "user": + self.db.commit() + self.db.refresh(message) + + if conversation.is_first_user_message and role == "user": conversation.title = ( content[:50] + ("..." if len(content) > 50 else "") ) - self.db.commit() - self.db.refresh(message) - logger.info( "Message added successfully", extra={ diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index e188872f..c658cf93 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -449,15 +449,16 @@ class AgentRunService: features_config: Dict[str, Any], is_new_conversation: bool, variables: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + ) -> tuple[Any, Any]: """首轮对话时返回开场白文本(支持变量替换),否则返回 None""" if not is_new_conversation: - return None + return None, None opening = features_config.get("opening_statement", {}) if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")): - return None + return None, None statement = opening["statement"] + suggested_questions = opening["suggested_questions"] # 如果有变量,进行替换(仅支持 {{var_name}} 格式) if variables: @@ -465,7 +466,7 @@ class AgentRunService: placeholder = f"{{{{{var_name}}}}}" statement = statement.replace(placeholder, str(var_value)) - return statement + return statement, suggested_questions @staticmethod def _filter_citations( @@ -599,13 +600,16 @@ class AgentRunService: # 5. 处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id - opening = self._get_opening_statement(features_config, is_new_conversation, variables) + opening, suggested_questions = None, None + if not sub_agent: + opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, - opening_statement=opening + opening_statement=opening, + suggested_questions=suggested_questions ) model_info = ModelInfo( @@ -845,14 +849,17 @@ class AgentRunService: # 5. 处理会话ID(创建或验证),新会话时写入开场白 is_new_conversation = not conversation_id - opening = self._get_opening_statement(features_config, is_new_conversation, variables) + opening, suggested_questions = None, None + if not sub_agent: + opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables) conversation_id = await self._ensure_conversation( conversation_id=conversation_id, app_id=agent_config.app_id, workspace_id=workspace_id, user_id=user_id, sub_agent=sub_agent, - opening_statement=opening + opening_statement=opening, + suggested_questions=suggested_questions ) model_info = ModelInfo( @@ -1061,7 +1068,8 @@ class AgentRunService: workspace_id: uuid.UUID, user_id: Optional[str], sub_agent: bool = False, - opening_statement: Optional[str] = None + opening_statement: Optional[str] = None, + suggested_questions: Optional[List[str]] = None ) -> str: """确保会话存在(创建或验证) @@ -1072,6 +1080,7 @@ class AgentRunService: user_id: 用户ID sub_agent: 是否为子代理 opening_statement: 开场白(新会话时作为第一条消息写入) + suggested_questions: 预设问题列表 Returns: str: 会话ID @@ -1115,7 +1124,7 @@ class AgentRunService: conversation_id=uuid.UUID(new_conv_id), role="assistant", content=opening_statement, - meta_data={} + meta_data={"suggested_questions": suggested_questions} ) logger.debug(f"已保存开场白到会话 {new_conv_id}") From 4e9b5736b151cd2947d1fe1a10f424315247fa2d Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Fri, 27 Mar 2026 15:35:47 +0800 Subject: [PATCH 075/120] feat(cache): Add thread-safe Redis client and enable activity stats cache - Add get_thread_safe_redis() function with thread-local storage and PID checking to prevent "Future attached to a different loop" errors in Celery thread and prefork pools - Implement health_check_interval=30 to prevent stale connection errors after fork - Uncomment and enable ActivityStatsCache module in cache/memory/__init__.py - Uncomment ActivityStatsCache implementation in activity_stats_cache.py and update to use get_thread_safe_redis() - Update interest_memory.py to use thread-safe Redis client - Update write_tools.py to use thread-safe Redis client - Remove redundant Chinese comments from aioRedis.py for cleaner code - Ensures safe Redis operations across different execution contexts and Celery worker configurations --- api/app/aioRedis.py | 39 +++- api/app/cache/memory/__init__.py | 4 +- api/app/cache/memory/activity_stats_cache.py | 210 +++++++++--------- api/app/cache/memory/interest_memory.py | 8 +- .../core/memory/agent/utils/write_tools.py | 38 ++-- 5 files changed, 167 insertions(+), 132 deletions(-) diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index aac2aa84..f79ef0e1 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -1,6 +1,8 @@ import asyncio import json import logging +import os +import threading from typing import Dict, Any, Optional import redis.asyncio as redis @@ -21,6 +23,41 @@ pool = ConnectionPool.from_url( ) aio_redis = redis.StrictRedis(connection_pool=pool) +_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}" + +# Thread-local storage for connection pools. +# Each thread (and each forked process) gets its own pool to avoid +# "Future attached to a different loop" errors in Celery --pool=threads +# and stale connections after fork in --pool=prefork. +_thread_local = threading.local() + + +def get_thread_safe_redis() -> redis.StrictRedis: + """Get a Redis client safe for the current execution context. + + Uses thread-local storage with PID checking to ensure: + - Each thread gets its own ConnectionPool (Celery --pool=threads) + - Pools are recreated after fork (Celery --pool=prefork) + - health_check_interval prevents stale connection errors + + Returns: + redis.StrictRedis: A Redis client with a thread/process-local pool. + """ + current_pid = os.getpid() + + if not hasattr(_thread_local, "pool") or getattr(_thread_local, "pid", None) != current_pid: + _thread_local.pid = current_pid + _thread_local.pool = ConnectionPool.from_url( + _REDIS_URL, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD, + decode_responses=True, + max_connections=5, + health_check_interval=30, + ) + + return redis.StrictRedis(connection_pool=_thread_local.pool) + async def get_redis_connection(): """获取Redis连接""" @@ -44,10 +81,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None): val = json.dumps(val, ensure_ascii=False) if expire is not None: - # 设置带过期时间的键值 await aio_redis.set(key, val, ex=expire) else: - # 设置永久键值 await aio_redis.set(key, val) except Exception as e: logger.error(f"Redis set错误: {str(e)}") diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index fa9ad1b1..551062ac 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -4,9 +4,9 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 """ from .interest_memory import InterestMemoryCache -# from .activity_stats_cache import ActivityStatsCache +from .activity_stats_cache import ActivityStatsCache __all__ = [ "InterestMemoryCache", - # "ActivityStatsCache", + "ActivityStatsCache", ] diff --git a/api/app/cache/memory/activity_stats_cache.py b/api/app/cache/memory/activity_stats_cache.py index 35c702b1..e0008353 100644 --- a/api/app/cache/memory/activity_stats_cache.py +++ b/api/app/cache/memory/activity_stats_cache.py @@ -1,124 +1,124 @@ -# """ -# Recent Activity Stats Cache +""" +Recent Activity Stats Cache -# 记忆提取活动统计缓存模块 -# 用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 -# 查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 -# """ -# import json -# import logging -# from typing import Optional, Dict, Any -# from datetime import datetime +记忆提取活动统计缓存模块 +用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 +查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 +""" +import json +import logging +from typing import Optional, Dict, Any +from datetime import datetime -# from app.aioRedis import aio_redis +from app.aioRedis import get_thread_safe_redis -# logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) -# # 缓存过期时间:24小时 -# ACTIVITY_STATS_CACHE_EXPIRE = 86400 +# 缓存过期时间:24小时 +ACTIVITY_STATS_CACHE_EXPIRE = 86400 -# class ActivityStatsCache: -# """记忆提取活动统计缓存类""" +class ActivityStatsCache: + """记忆提取活动统计缓存类""" -# PREFIX = "cache:memory:activity_stats" + PREFIX = "cache:memory:activity_stats" -# @classmethod -# def _get_key(cls, workspace_id: str) -> str: -# """生成 Redis key + @classmethod + def _get_key(cls, workspace_id: str) -> str: + """生成 Redis key -# Args: -# workspace_id: 工作空间ID + Args: + workspace_id: 工作空间ID -# Returns: -# 完整的 Redis key -# """ -# return f"{cls.PREFIX}:by_workspace:{workspace_id}" + Returns: + 完整的 Redis key + """ + return f"{cls.PREFIX}:by_workspace:{workspace_id}" -# @classmethod -# async def set_activity_stats( -# cls, -# workspace_id: str, -# stats: Dict[str, Any], -# expire: int = ACTIVITY_STATS_CACHE_EXPIRE, -# ) -> bool: -# """设置记忆提取活动统计缓存 + @classmethod + async def set_activity_stats( + cls, + workspace_id: str, + stats: Dict[str, Any], + expire: int = ACTIVITY_STATS_CACHE_EXPIRE, + ) -> bool: + """设置记忆提取活动统计缓存 -# Args: -# workspace_id: 工作空间ID -# stats: 统计数据,格式: -# { -# "chunk_count": int, -# "statements_count": int, -# "triplet_entities_count": int, -# "triplet_relations_count": int, -# "temporal_count": int, -# } -# expire: 过期时间(秒),默认24小时 + Args: + workspace_id: 工作空间ID + stats: 统计数据,格式: + { + "chunk_count": int, + "statements_count": int, + "triplet_entities_count": int, + "triplet_relations_count": int, + "temporal_count": int, + } + expire: 过期时间(秒),默认24小时 -# Returns: -# 是否设置成功 -# """ -# try: -# key = cls._get_key(workspace_id) -# payload = { -# "stats": stats, -# "generated_at": datetime.now().isoformat(), -# "workspace_id": workspace_id, -# "cached": True, -# } -# value = json.dumps(payload, ensure_ascii=False) -# await aio_redis.set(key, value, ex=expire) -# logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") -# return True -# except Exception as e: -# logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) -# return False + Returns: + 是否设置成功 + """ + try: + key = cls._get_key(workspace_id) + payload = { + "stats": stats, + "generated_at": datetime.now().isoformat(), + "workspace_id": workspace_id, + "cached": True, + } + value = json.dumps(payload, ensure_ascii=False) + await get_thread_safe_redis().set(key, value, ex=expire) + logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) + return False -# @classmethod -# async def get_activity_stats( -# cls, -# workspace_id: str, -# ) -> Optional[Dict[str, Any]]: -# """获取记忆提取活动统计缓存 + @classmethod + async def get_activity_stats( + cls, + workspace_id: str, + ) -> Optional[Dict[str, Any]]: + """获取记忆提取活动统计缓存 -# Args: -# workspace_id: 工作空间ID + Args: + workspace_id: 工作空间ID -# Returns: -# 统计数据字典,缓存不存在或已过期返回 None -# """ -# try: -# key = cls._get_key(workspace_id) -# value = await aio_redis.get(key) -# if value: -# payload = json.loads(value) -# logger.info(f"命中活动统计缓存: {key}") -# return payload -# logger.info(f"活动统计缓存不存在或已过期: {key}") -# return None -# except Exception as e: -# logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) -# return None + Returns: + 统计数据字典,缓存不存在或已过期返回 None + """ + try: + key = cls._get_key(workspace_id) + value = await get_thread_safe_redis().get(key) + if value: + payload = json.loads(value) + logger.info(f"命中活动统计缓存: {key}") + return payload + logger.info(f"活动统计缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) + return None -# @classmethod -# async def delete_activity_stats( -# cls, -# workspace_id: str, -# ) -> bool: -# """删除记忆提取活动统计缓存 + @classmethod + async def delete_activity_stats( + cls, + workspace_id: str, + ) -> bool: + """删除记忆提取活动统计缓存 -# Args: -# workspace_id: 工作空间ID + Args: + workspace_id: 工作空间ID -# Returns: -# 是否删除成功 -# """ -# try: -# key = cls._get_key(workspace_id) -# result = await aio_redis.delete(key) -# logger.info(f"删除活动统计缓存: {key}, 结果: {result}") -# return result > 0 -# except Exception as e: -# logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) -# return False + Returns: + 是否删除成功 + """ + try: + key = cls._get_key(workspace_id) + result = await get_thread_safe_redis().delete(key) + logger.info(f"删除活动统计缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) + return False diff --git a/api/app/cache/memory/interest_memory.py b/api/app/cache/memory/interest_memory.py index 108e2a37..2881f06c 100644 --- a/api/app/cache/memory/interest_memory.py +++ b/api/app/cache/memory/interest_memory.py @@ -9,7 +9,7 @@ import logging from typing import Optional, List, Dict, Any from datetime import datetime -from app.aioRedis import aio_redis +from app.aioRedis import get_thread_safe_redis logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class InterestMemoryCache: "cached": True, } value = json.dumps(payload, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) + await get_thread_safe_redis().set(key, value, ex=expire) logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒") return True except Exception as e: @@ -86,7 +86,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - value = await aio_redis.get(key) + value = await get_thread_safe_redis().get(key) if value: payload = json.loads(value) logger.info(f"命中兴趣分布缓存: {key}") @@ -114,7 +114,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - result = await aio_redis.delete(key) + result = await get_thread_safe_redis().delete(key) logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") return result > 0 except Exception as e: diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index c01a36d1..55bcb8ba 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -260,24 +260,24 @@ async def write( with open(log_file, "a", encoding="utf-8") as f: f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") - # # 将提取统计写入 Redis,按 workspace_id 存储 - # try: - # from app.cache.memory.activity_stats_cache import ActivityStatsCache + # 将提取统计写入 Redis,按 workspace_id 存储 + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache - # stats_to_cache = { - # "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, - # "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, - # "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, - # "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, - # "temporal_count": 0, - # } - # await ActivityStatsCache.set_activity_stats( - # workspace_id=str(memory_config.workspace_id), - # stats=stats_to_cache, - # ) - # logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") - # except Exception as cache_err: - # logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + stats_to_cache = { + "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, + "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, + "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, + "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, + "temporal_count": 0, + } + await ActivityStatsCache.set_activity_stats( + workspace_id=str(memory_config.workspace_id), + stats=stats_to_cache, + ) + logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") + except Exception as cache_err: + logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) - # logger.info("=== Pipeline Complete ===") - # logger.info(f"Total execution time: {total_time:.2f} seconds") + logger.info("=== Pipeline Complete ===") + logger.info(f"Total execution time: {total_time:.2f} seconds") From 289b1989e5f149ccc3200d562d1f5ab744321e1a Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Mar 2026 19:13:38 +0800 Subject: [PATCH 076/120] [changes] Semantic pruning enables the file to pass through --- .../data_preprocessing/data_pruning.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 967f529e..223345b4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -128,9 +128,15 @@ class SemanticPruner: 1. 空消息 2. 场景特定填充词库精确匹配 3. 常见寒暄精确匹配 - 4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了") + 4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了") 5. 纯表情/标点 + + 注意:如果消息包含文件(files 字段非空),则不视为填充消息,予以保留。 """ + # 保护带有文件的消息:文件包含感知记忆信息,不应被删除 + if message.files and len(message.files) > 0: + return False + t = message.msg.strip() if not t: return True @@ -482,6 +488,12 @@ class SemanticPruner: """ to_delete_ids: set = set() for m in msgs: + # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 + has_files = m.files and len(m.files) > 0 + if has_files: + self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}") + continue + # 填充检测优先:先判断是否为填充,再看 LLM 保护 if self._is_filler_message(m): to_delete_ids.add(id(m)) @@ -549,6 +561,12 @@ class SemanticPruner: to_delete_ids: set = set() for m in msgs: msg_text = m.msg.strip() + + # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 + has_files = m.files and len(m.files) > 0 + if has_files: + self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}") + continue # 第一优先级:填充消息无论模式直接删除,不参与后续场景判断 if self._is_filler_message(m): @@ -759,6 +777,11 @@ class SemanticPruner: msgs = dd.context.msgs original_count = len(msgs) total_original_msgs += original_count + + # 统计带文件的消息数量 + files_msg_count = sum(1 for m in msgs if m.files and len(m.files) > 0) + if files_msg_count > 0: + self._log(f"[剪枝-对话{d_idx+1}] 检测到 {files_msg_count}/{original_count} 条消息带有文件,将予以保护") # 相关对话:根据阶段决定处理力度 if extraction.is_related: @@ -801,6 +824,13 @@ class SemanticPruner: for idx, m in enumerate(msgs): msg_text = m.msg.strip() + + # 最高优先级保护:带有文件的消息一律保留,不参与分类 + has_files = m.files and len(m.files) > 0 + if has_files: + self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}") + llm_protected_msgs.append((idx, m)) # 放入保护列表 + continue if self._msg_matches_tokens(m, preserve_tokens): llm_protected_msgs.append((idx, m)) From f485398768af135ba32860ff60444b54775d6471 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Fri, 27 Mar 2026 19:13:51 +0800 Subject: [PATCH 077/120] fix(workflow): Parsing of DOC files --- api/app/models/conversation_model.py | 6 -- api/app/services/conversation_service.py | 8 +-- api/app/services/multimodal_service.py | 85 ++++++++++++++++++------ 3 files changed, 70 insertions(+), 29 deletions(-) diff --git a/api/app/models/conversation_model.py b/api/app/models/conversation_model.py index 4ae9034d..4011247f 100644 --- a/api/app/models/conversation_model.py +++ b/api/app/models/conversation_model.py @@ -57,12 +57,6 @@ class Conversation(Base): workspace = relationship("Workspace") messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") - @property - def is_first_user_message(self): - """判断当前是否是用户的第一条消息(无视开场白)""" - user_message_count = sum(1 for msg in self.messages if msg.role == "user") - return user_message_count == 1 - class ConversationDetail(Base): __tablename__ = "conversation_details" diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index ecf316d9..bd7f7496 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -214,14 +214,14 @@ class ConversationService: conversation.message_count += 1 - self.db.commit() - self.db.refresh(message) - - if conversation.is_first_user_message and role == "user": + if conversation.message_count <= 2 and role == "user": conversation.title = ( content[:50] + ("..." if len(content) > 50 else "") ) + self.db.commit() + self.db.refresh(message) + logger.info( "Message added successfully", extra={ diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 4cf3d89d..f854e987 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -12,6 +12,9 @@ import base64 import csv import io import json +import re +import olefile +import struct import zipfile from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional @@ -602,31 +605,75 @@ class MultimodalService: try: word_file = io.BytesIO(file_content) doc = Document(word_file) - return '\n'.join(p.text for p in doc.paragraphs) + text_lines = [] + for p in doc.paragraphs: + text = p.text.strip() + if text: + text_lines.append(text) + + for table in doc.tables: + for row in table.rows: + for cell in row.cells: + text = cell.text.strip() + if text: + text_lines.append(text) + + full_text = "\n".join(text_lines) + return full_text.strip() or "[docx 文件无文本内容]" except Exception as e: - logger.error(f"提取 docx 文本失败: {e}") + logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True) return f"[docx 提取失败: {str(e)}]" - # 旧版 .doc(OLE2 格式) + # 旧版 .doc(OLE2/CFB 格式),按 Word Binary Format 规范解析 piece table try: - import olefile ole = olefile.OleFileIO(io.BytesIO(file_content)) - if not ole.exists('WordDocument'): - return "[doc 提取失败: 未找到 WordDocument 流]" - # 读取 WordDocument 流,提取可见 ASCII/Unicode 文本 - stream = ole.openstream('WordDocument').read() - # Word Binary Format: 文本在流中以 UTF-16-LE 编码存储 - # 简单提取:过滤出可打印字符段 - try: - text = stream.decode('utf-16-le', errors='ignore') - except Exception: - text = stream.decode('latin-1', errors='ignore') - # 过滤控制字符,保留可打印内容 - import re - text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text) - text = re.sub(r' +', ' ', text).strip() + word_stream = ole.openstream('WordDocument').read() + + # FIB offset 0xA bit9 决定使用 0Table 还是 1Table + fib_flags = struct.unpack_from(' Date: Fri, 27 Mar 2026 19:25:17 +0800 Subject: [PATCH 078/120] [changes] Semantic pruning enables the file to pass through --- .../data_preprocessing/data_pruning.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 223345b4..5390197a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene logger = logging.getLogger(__name__) +def message_has_files(message: "ConversationMessage") -> bool: + """检查消息是否包含文件。 + + Args: + message: 待检查的消息对象 + + Returns: + bool: 如果消息包含文件则返回 True,否则返回 False + """ + return message.files and len(message.files) > 0 + + class DialogExtractionResponse(BaseModel): """对话级一次性抽取的结构化返回,用于加速剪枝。 @@ -130,13 +142,7 @@ class SemanticPruner: 3. 常见寒暄精确匹配 4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了") 5. 纯表情/标点 - - 注意:如果消息包含文件(files 字段非空),则不视为填充消息,予以保留。 """ - # 保护带有文件的消息:文件包含感知记忆信息,不应被删除 - if message.files and len(message.files) > 0: - return False - t = message.msg.strip() if not t: return True @@ -489,8 +495,7 @@ class SemanticPruner: to_delete_ids: set = set() for m in msgs: # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 - has_files = m.files and len(m.files) > 0 - if has_files: + if message_has_files(m): self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}") continue @@ -563,8 +568,7 @@ class SemanticPruner: msg_text = m.msg.strip() # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 - has_files = m.files and len(m.files) > 0 - if has_files: + if message_has_files(m): self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}") continue @@ -777,11 +781,6 @@ class SemanticPruner: msgs = dd.context.msgs original_count = len(msgs) total_original_msgs += original_count - - # 统计带文件的消息数量 - files_msg_count = sum(1 for m in msgs if m.files and len(m.files) > 0) - if files_msg_count > 0: - self._log(f"[剪枝-对话{d_idx+1}] 检测到 {files_msg_count}/{original_count} 条消息带有文件,将予以保护") # 相关对话:根据阶段决定处理力度 if extraction.is_related: @@ -826,8 +825,7 @@ class SemanticPruner: msg_text = m.msg.strip() # 最高优先级保护:带有文件的消息一律保留,不参与分类 - has_files = m.files and len(m.files) > 0 - if has_files: + if message_has_files(m): self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}") llm_protected_msgs.append((idx, m)) # 放入保护列表 continue From b699b746a5a6f26d696e1159ae6740ac463b7e6b Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 10:17:27 +0800 Subject: [PATCH 079/120] fix(web): log --- web/src/components/SiderMenu/index.tsx | 3 ++- web/src/views/ApplicationConfig/Logs.tsx | 10 +++++----- web/src/views/InviteRegister/index.tsx | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/web/src/components/SiderMenu/index.tsx b/web/src/components/SiderMenu/index.tsx index 3bd0cea3..21f7fd36 100644 --- a/web/src/components/SiderMenu/index.tsx +++ b/web/src/components/SiderMenu/index.tsx @@ -128,6 +128,7 @@ const Menu: FC<{ /** Filter menus based on user role and source */ useEffect(() => { + if (!user) return let menuList: MenuItem[] = [] if (user.role === 'member' && source === 'space') { @@ -136,7 +137,7 @@ const Menu: FC<{ menuList = allMenus[source] || [] } - const noAuthList = ['user', 'pricing'].filter(vo => !user.permissions?.includes(vo) && !user.permissions?.includes('all')) + const noAuthList = ['user', 'pricing'].filter(vo => (Array.isArray(user.permissions) && !user.permissions?.includes(vo) && !user.permissions?.includes('all')) || !Array.isArray(user.permissions)) if (noAuthList && !noAuthList?.includes('all')) { const filterMenus = (list: MenuItem[]): MenuItem[] =>{ diff --git a/web/src/views/ApplicationConfig/Logs.tsx b/web/src/views/ApplicationConfig/Logs.tsx index 88fa2607..49a5bbd6 100644 --- a/web/src/views/ApplicationConfig/Logs.tsx +++ b/web/src/views/ApplicationConfig/Logs.tsx @@ -34,16 +34,16 @@ const Statistics: FC = () => { className: 'rb:text-[#212332]' }, { - title: t('user.createTime'), + title: t('application.createTime'), dataIndex: 'created_at', key: 'created_at', render: (createdAt: string) => formatDateTime(createdAt, 'YYYY-MM-DD HH:mm:ss'), }, { - title: t('user.lastLoginTime'), - dataIndex: 'last_login_at', - key: 'last_login_at', - render: (lastLoginAt: string) => lastLoginAt ? formatDateTime(lastLoginAt, 'YYYY-MM-DD HH:mm:ss') : '-', + title: t('common.updated_at'), + dataIndex: 'updated_at', + key: 'updated_at', + render: (updatedAt: string) => updatedAt ? formatDateTime(updatedAt, 'YYYY-MM-DD HH:mm:ss') : '-', }, { title: t('common.operation'), diff --git a/web/src/views/InviteRegister/index.tsx b/web/src/views/InviteRegister/index.tsx index 42cffff1..72ae55e5 100644 --- a/web/src/views/InviteRegister/index.tsx +++ b/web/src/views/InviteRegister/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:37:12 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-04 10:05:39 + * @Last Modified time: 2026-03-27 22:22:18 */ /** * Invite Register Page @@ -144,7 +144,7 @@ const InviteRegister: React.FC = () => { }).then((res) => { const response = res as LoginInfo; updateLoginInfo(response); - navigate('/'); + navigate('/', { replace: true }); }).finally(() => { setLoading(false); }); From 8f216db35343dc1f4f500affc8b565d84980b300 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 11:35:09 +0800 Subject: [PATCH 080/120] [fix] Remove the limit on the number of output items. --- api/app/services/memory_forget_service.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 11118571..5122ae02 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -203,8 +203,7 @@ class MemoryForgetService: connector: Neo4jConnector, end_user_id: str, forgetting_threshold: float, - min_days_since_access: int, - limit: int = 20 + min_days_since_access: int ) -> list[Dict[str, Any]]: """ 获取待遗忘节点列表 @@ -216,7 +215,6 @@ class MemoryForgetService: end_user_id: 组ID forgetting_threshold: 遗忘阈值 min_days_since_access: 最小未访问天数 - limit: 返回节点数量限制 Returns: list: 待遗忘节点列表 @@ -247,14 +245,12 @@ class MemoryForgetService: n.activation_value as activation_value, n.last_access_time as last_access_time ORDER BY n.activation_value ASC - LIMIT $limit """ params = { 'end_user_id': end_user_id, 'threshold': forgetting_threshold, - 'min_access_time_str': min_access_time_str, - 'limit': limit + 'min_access_time_str': min_access_time_str } results = await connector.execute_query(query, **params) @@ -636,7 +632,7 @@ class MemoryForgetService: api_logger.error(f"获取历史趋势数据失败: {str(e)}") # 失败时返回空列表,不影响主流程 - # 获取待遗忘节点列表(前20个满足遗忘条件的节点) + # 获取待遗忘节点列表 pending_nodes = [] try: if end_user_id: @@ -652,8 +648,7 @@ class MemoryForgetService: connector=connector, end_user_id=end_user_id, forgetting_threshold=forgetting_threshold, - min_days_since_access=int(min_days), - limit=20 + min_days_since_access=int(min_days) ) api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点") From 13352178ad15849daac6cff4d9b0a8ef2ae9a319 Mon Sep 17 00:00:00 2001 From: wxy Date: Mon, 30 Mar 2026 11:55:21 +0800 Subject: [PATCH 081/120] fix: standardize app list pagination and fix session log isolation --- api/app/controllers/app_controller.py | 7 ++++++- api/app/controllers/app_log_controller.py | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 352e0f0c..74991bcf 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -95,7 +95,12 @@ def list_apps( if app_ids: items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) items = [service._convert_to_schema(app, workspace_id) for app in items_orm] - return success(data=items) + # 返回标准分页格式 + meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False) + return success(data=PageData(page=meta, items=items)) + # ids 为空时,返回空列表 + meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False) + return success(data=PageData(page=meta, items=[])) # 正常分页查询 items_orm, total = app_service.list_apps( diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py index dfd10644..ac0b2ac4 100644 --- a/api/app/controllers/app_log_controller.py +++ b/api/app/controllers/app_log_controller.py @@ -35,6 +35,7 @@ def list_app_logs( - 支持按 user_id 筛选 - 支持按 is_draft 筛选(草稿会话 / 发布会话) - 按最新更新时间倒序排列 + - 所有人(包括共享者和被共享者)都只能查看自己的会话记录 """ workspace_id = current_user.current_workspace_id @@ -47,6 +48,9 @@ def list_app_logs( Conversation.workspace_id == workspace_id, Conversation.is_active.is_(True), ) + + # 所有人只能查看自己的会话记录 + stmt = stmt.where(Conversation.user_id == str(current_user.id)) if user_id: stmt = stmt.where(Conversation.user_id == user_id) @@ -86,6 +90,7 @@ def get_app_log_detail( - 返回会话基本信息 + 所有消息(按时间正序) - 消息 meta_data 包含模型名、token 用量等信息 + - 所有人(包括共享者和被共享者)都只能查看自己的会话详情 """ workspace_id = current_user.current_workspace_id @@ -100,6 +105,7 @@ def get_app_log_detail( Conversation.app_id == app_id, Conversation.workspace_id == workspace_id, Conversation.is_active.is_(True), + Conversation.user_id == str(current_user.id), ) ).first() From 3aed5c447a53de8e678e1b844a05b9e3f6cc4341 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 13:36:02 +0800 Subject: [PATCH 082/120] fix(web): forget memory's pending nodes support page --- web/src/api/memory.ts | 2 ++ .../ApplicationConfig/components/ConfigHeader.tsx | 10 +++++----- .../components/FeaturesConfig/index.tsx | 2 +- web/src/views/Prompt/pages/History.tsx | 6 +++--- web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx | 11 +++++------ 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 1ec2d7dc..4467b649 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -154,6 +154,8 @@ export const analyticsRefresh = (end_user_id: string) => { export const getForgetStats = (end_user_id: string) => { return request.get(`/memory/forget-memory/stats`, { end_user_id }) } +// 获取带遗忘节点列表 +export const getForgetPendingNodesUrl = '/memory/forget-memory/pending-nodes' // Implicit Memory - Preferences export const getImplicitPreferences = (end_user_id: string) => { return request.get(`/memory/implicit-memory/preferences/${end_user_id}`) diff --git a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx index 8e6fc875..bebf6ebd 100644 --- a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx +++ b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx @@ -220,31 +220,31 @@ const ConfigHeader: FC = ({ />
diff --git a/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx b/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx index dba03ab2..3fb7bc93 100644 --- a/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx +++ b/web/src/views/ApplicationConfig/components/FeaturesConfig/index.tsx @@ -49,7 +49,7 @@ const FeaturesConfig: FC = ({ ?
diff --git a/web/src/views/Prompt/pages/History.tsx b/web/src/views/Prompt/pages/History.tsx index 573b4a90..19c033ed 100644 --- a/web/src/views/Prompt/pages/History.tsx +++ b/web/src/views/Prompt/pages/History.tsx @@ -116,13 +116,13 @@ const History: React.FC = () => {
{formatDateTime(item.created_at, 'YYYY/MM/DD HH:mm')}
-
handleClick('detail', item)} >
-
handleClick('edit', item)} >
-
handleClick('delete', item)} >
diff --git a/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx index 2510aaa9..04391107 100644 --- a/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx @@ -12,6 +12,7 @@ import { Row, Col, Progress, App, Table } from 'antd' import RbCard from '@/components/RbCard/Card' import { getForgetStats, + getForgetPendingNodesUrl, } from '@/api/memory' import type { ForgetData } from '../types' import ActivationMetricsPieCard from '../components/ActivationMetricsPieCard' @@ -19,6 +20,7 @@ import RecentTrendsLineCard from '../components/RecentTrendsLineCard' import { formatDateTime } from '@/utils/format' import StatusTag from '@/components/StatusTag' import ForgetRefreshModal from '../components/ForgetRefreshModal'; +import RbTable from '@/components/Table' /** Maps node type keys to StatusTag colour presets for the pending-nodes table. */ const statusTagColors: Record = { @@ -191,7 +193,9 @@ const ForgetDetail = forwardRef((_props, ref) => { bodyClassName="rb:p-3! rb:py-0! rb:h-[calc(100%-54px)]" className="rb:h-full!" > - { render: (activation_value) => {activation_value} }, ]} - pagination={{ - pageSize: 5, - showQuickJumper: true, - className: 'rb:mt-5! rb:mb-5.75!' - }} className="table-header-has-bg" /> From 7acb7045f081046de2f2d381c26eb833062ff39c Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 30 Mar 2026 11:47:58 +0800 Subject: [PATCH 083/120] feat(agent, memory): add agent-perceived memory writing --- .../controllers/public_share_controller.py | 80 -------- api/app/core/agent/langchain_agent.py | 90 ++------- .../langgraph_graph/routing/write_router.py | 83 +++------ .../agent/langgraph_graph/write_graph.py | 116 ++++-------- api/app/core/memory/agent/utils/redis_tool.py | 173 +++++++++--------- api/app/core/memory/llm_tools/llm_client.py | 2 +- api/app/schemas/memory_agent_schema.py | 10 +- api/app/services/app_chat_service.py | 65 ++++--- api/app/services/draft_run_service.py | 13 +- api/app/services/memory_perceptual_service.py | 21 --- api/app/services/model_service.py | 138 +++++++------- api/app/services/shared_chat_service.py | 43 ++--- 12 files changed, 304 insertions(+), 530 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index f5284b46..26902b07 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -410,30 +410,6 @@ async def chat( agent_config = agent_config_4_app_release(release) if payload.stream: - # async def event_generator(): - # async for event in service.chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) async def event_generator(): async for event in app_chat_service.agnet_chat_stream( message=payload.message, @@ -459,20 +435,6 @@ async def chat( "X-Accel-Buffering": "no" } ) - # 非流式返回 - # result = await service.chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - # return success(data=conversation_schema.ChatResponse(**result)) result = await app_chat_service.agnet_chat( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID @@ -531,48 +493,6 @@ async def chat( ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) - # 多 Agent 流式返回 - # if payload.stream: - # async def event_generator(): - # async for event in service.multi_agent_chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) - - # # 多 Agent 非流式返回 - # result = await service.multi_agent_chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - - # return success(data=conversation_schema.ChatResponse(**result)) elif app_type == AppType.WORKFLOW: config = workflow_config_4_app_release(release) if not config.id: diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 464a668a..38821313 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,18 +11,14 @@ LangChain Agent 封装 import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.write_graph import write_long_term -from app.db import get_db -from app.core.logging_config import get_business_logger -from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType, ModelProvider -from app.services.memory_agent_service import ( - get_end_user_connected_config, -) from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool +from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig +from app.models.models_model import ModelType + logger = get_business_logger() @@ -226,10 +222,9 @@ class LangChainAgent: Returns: List[BaseMessage]: 消息列表 """ - messages = [] + messages:list = [SystemMessage(content=self.system_prompt)] # 添加系统提示词 - messages.append(SystemMessage(content=self.system_prompt)) # 添加历史消息 if history: @@ -293,12 +288,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, # 添加这个参数 - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> Dict[str, Any]: """执行对话 @@ -306,32 +296,12 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] context: 上下文信息(如知识库检索结果) + files: 多模态文件 Returns: Dict: 包含 content 和元数据的字典 """ - message_chat = message start_time = time.time() - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -419,9 +389,6 @@ class LangChainAgent: logger.info(f"最终提取的内容长度: {len(content)}") elapsed_time = time.time() - start_time - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, - actual_config_id) response = { "content": content, "model": self.model_name, @@ -452,12 +419,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> AsyncGenerator[str, None]: """执行流式对话 @@ -465,6 +427,7 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 context: 上下文信息 + files: 多模态文件 Yields: str: 消息内容块 @@ -475,23 +438,6 @@ class LangChainAgent: logger.info(f" Has tools: {bool(self.tools)}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info("=" * 80) - message_chat = message - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - - # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -501,17 +447,18 @@ class LangChainAgent: ) chunk_count = 0 - yielded_content = False # 统一使用 agent 的 astream_events 实现流式输出 logger.debug("使用 Agent astream_events 实现流式输出") full_content = '' try: + last_event = {} async for event in self.agent.astream_events( {"messages": messages}, version="v2", config={"recursion_limit": self.max_iterations} ): + last_event = event chunk_count += 1 kind = event.get("event") @@ -525,7 +472,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -536,18 +482,15 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif kind == "on_llm_stream": # 另一种 LLM 流式事件 @@ -558,7 +501,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -569,22 +511,18 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif isinstance(chunk, str): full_content += chunk yield chunk - yielded_content = True # 记录工具调用(可选) elif kind == "on_tool_start": @@ -594,7 +532,7 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 - output_messages = event.get("data", {}).get("output", {}).get("messages", []) + output_messages = last_event.get("data", {}).get("output", {}).get("messages", []) for msg in reversed(output_messages): if isinstance(msg, AIMessage): response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None @@ -604,9 +542,7 @@ class LangChainAgent: ) if response_meta else 0 yield total_tokens break - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, - actual_config_id) + except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 2074b6ca..74fb6bae 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term -from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id @@ -21,25 +20,6 @@ logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): - """ - Write messages to RAG storage system - - Combines user and AI messages into a single string format and stores them - in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. - - Args: - end_user_id: User identifier for the conversation - user_message: User's input message content - ai_message: AI's response message content - user_rag_memory_id: RAG memory identifier for storage location - """ - # RAG mode: combine messages into string format (maintain original logic) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - - async def write( storage_type, end_user_id, @@ -118,7 +98,7 @@ async def write( logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') -async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): +async def term_memory_save(end_user_id, strategy_type, scope): """ Save long-term memory data to database @@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty to long-term memory storage. Args: - long_term_messages: Long-term message data to be saved - actual_config_id: Configuration identifier for memory settings end_user_id: User identifier for memory association - type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) + strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) scope: Scope/window size for memory processing """ with get_db_context() as db_session: @@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) - if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + if not result: + logger.warning(f"No write data found for user {end_user_id}") + return + if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]: data = await format_parsing(result, "dict") chunk_data = data[:scope] if len(chunk_data) == scope: @@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty logger.info(f'写入短长期:') -"""Window-based dialogue processing""" - - async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): """ Process dialogue based on window size and write to Neo4j @@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) langchain_messages: Original message data list scope: Window size determining when to trigger long-term storage """ - scope = scope - is_end_user_id = count_store.get_sessions_count(end_user_id) - if is_end_user_id is not False: - is_end_user_id = count_store.get_sessions_count(end_user_id)[0] - redis_messages = count_store.get_sessions_count(end_user_id)[1] - if is_end_user_id and int(is_end_user_id) != int(scope): - is_end_user_id += 1 - langchain_messages += redis_messages - count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) - elif int(is_end_user_id) == int(scope): + is_end_user_has_history = count_store.get_sessions_count(end_user_id) + if is_end_user_has_history: + end_user_visit_count, redis_messages = is_end_user_has_history + else: + count_store.save_sessions_count(end_user_id, 1, langchain_messages) + return + end_user_visit_count += 1 + if end_user_visit_count < scope: + redis_messages.extend(langchain_messages) + count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages) + else: logger.info('写入长期记忆NEO4J') - formatted_messages = redis_messages + redis_messages.extend(langchain_messages) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id else: config_id = memory_config - await write( - AgentMemory_Long_Term.STORAGE_NEO4J, - end_user_id, - "", - "", - None, - end_user_id, - config_id, - formatted_messages + write_message_task.delay( + end_user_id, # end_user_id: User ID + redis_messages, # message: JSON string format message list + config_id, # config_id: Configuration ID string + AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" + "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) ) - count_store.update_sessions_count(end_user_id, 1, langchain_messages) - else: - count_store.save_sessions_count(end_user_id, 1, langchain_messages) - - -"""Time-based memory processing""" + count_store.update_sessions_count(end_user_id, 0, []) async def memory_long_term_storage(end_user_id, memory_config, time): @@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config return result_dict except Exception as e: - print(f"[aggregate_judgment] 发生错误: {e}") - import traceback - traceback.print_exc() + logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True) return { "is_same_event": False, diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index bf3c6597..32fc7d8a 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,49 +1,25 @@ -import asyncio -import json -import sys import warnings -from contextlib import asynccontextmanager -from langgraph.constants import END, START -from langgraph.graph import StateGraph -from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.llm_tools import WriteState -from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ + aggregate_judgment +from app.core.memory.agent.utils.redis_tool import write_store +from app.db import get_db_context from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService +from app.services.memory_konwledges_server import write_rag warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) -if sys.platform.startswith("win"): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - -@asynccontextmanager -async def make_write_graph(): - """ - Create a write graph workflow for memory operations. - - Args: - user_id: User identifier - tools: MCP tools loaded from session - apply_id: Application identifier - end_user_id: Group identifier - memory_config: MemoryConfig object containing all configuration - """ - workflow = StateGraph(WriteState) - workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "save_neo4j") - workflow.add_edge("save_neo4j", END) - - graph = workflow.compile() - - yield graph - - -async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', - end_user_id: str = '', scope: int = 6): +async def long_term_storage( + long_term_type: str, + langchain_messages: list, + memory_config_id: str, + end_user_id: str, + scope: int = 6 +): """ Handle long-term memory storage with different strategies @@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l Args: long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') langchain_messages: List of messages to store - memory_config: Memory configuration identifier + memory_config_id: Memory configuration identifier end_user_id: User group identifier scope: Scope parameter for chunk-based storage (default: 6) """ - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ - aggregate_judgment - from app.core.memory.agent.utils.redis_tool import write_store + if langchain_messages is None: + langchain_messages = [] + write_store.save_session_write(end_user_id, langchain_messages) # 获取数据库会话 with get_db_context() as db_session: config_service = MemoryConfigService(db_session) memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 + config_id=memory_config_id, # 改为整数 service_name="MemoryAgentService" ) if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: - '''Strategy 1: Dialogue window with 6 rounds of conversation''' + # Dialogue window with 6 rounds of conversation await window_dialogue(end_user_id, langchain_messages, memory_config, scope) if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: - """Time-based strategy""" + # Time-based strategy await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: - """Strategy 3: Aggregate judgment""" + # Aggregate judgment await aggregate_judgment(end_user_id, langchain_messages, memory_config) -async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): +async def write_long_term( + storage_type: str, + end_user_id: str, + messages: list[dict], + user_rag_memory_id: str, + actual_config_id: str +): """ Write long-term memory with different storage types @@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u Args: storage_type: Type of storage (RAG or traditional) end_user_id: User group identifier - message_chat: User message content - aimessages: AI response messages + messages: message list user_rag_memory_id: RAG memory identifier actual_config_id: Actual configuration ID """ - from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save - from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: - await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + message_content = [] + for message in messages: + message_content.append(f'{message.get("role")}:{message.get("content")}') + messages_string = "\n".join(message_content) + await write_rag(end_user_id, messages_string, user_rag_memory_id) else: # AI reply writing (user messages and AI replies paired, written as complete dialogue at once) CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE - long_term_messages = await agent_chat_messages(message_chat, aimessages) - await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) - await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) - -# async def main(): -# """主函数 - 运行工作流""" -# langchain_messages = [ -# { -# "role": "user", -# "content": "今天周五去爬山" -# }, -# { -# "role": "assistant", -# "content": "好耶" -# } -# -# ] -# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID -# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# -# -# -# if __name__ == "__main__": -# import asyncio -# asyncio.run(main()) + await long_term_storage(long_term_type=CHUNK, + langchain_messages=messages, + memory_config_id=actual_config_id, + end_user_id=end_user_id, + scope=SCOPE) + await term_memory_save(end_user_id, CHUNK, scope=SCOPE) diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index c5729628..82b22c9e 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -3,8 +3,9 @@ import uuid from app.core.config import settings from typing import List, Dict, Any, Optional, Union +from app.core.logging_config import get_logger from app.core.memory.agent.utils.redis_base import ( - serialize_messages, + serialize_messages, deserialize_messages, fix_encoding, format_session_data, @@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import ( get_current_timestamp ) - +logger = get_logger(__name__) class RedisWriteStore: """Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -66,10 +67,10 @@ class RedisWriteStore: }) result = pipe.execute() - print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session_write] 保存会话失败: {e}") + logger.error(f"[save_session_write] 保存会话失败: {e}") raise e def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: @@ -99,7 +100,7 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid: # 从 key 中提取 session_id: session:write:{session_id} @@ -108,16 +109,16 @@ class RedisWriteStore: "sessionid": session_id, "messages": fix_encoding(data.get('messages', '')) }) - + if not results: return False - - print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_session_by_userid] 查询失败: {e}") + logger.error(f"[get_session_by_userid] 查询失败: {e}") return False - + def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: """ 通过 end_user_id 获取所有 write 类型的会话数据 @@ -144,7 +145,7 @@ class RedisWriteStore: # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") + logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") return False # 批量获取数据 @@ -158,12 +159,12 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == end_user_id: # 从 key 中提取 session_id: session:write:{session_id} session_id = key.split(':')[-1] - + # 构建完整的会话信息 session_info = { "session_id": session_id, @@ -173,23 +174,21 @@ class RedisWriteStore: "starttime": data.get('starttime', '') } results.append(session_info) - + if not results: - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") return False - + # 按时间排序(最新的在前) results.sort(key=lambda x: x.get('starttime', ''), reverse=True) - - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") - import traceback - traceback.print_exc() + logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True) return False - def find_user_recent_sessions(self, userid: str, + def find_user_recent_sessions(self, userid: str, minutes: int = 5) -> List[Dict[str, str]]: """ 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 @@ -203,11 +202,11 @@ class RedisWriteStore: """ import time start_time = time.time() - + # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -221,7 +220,7 @@ class RedisWriteStore: for data in all_data: if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid and data.get('starttime'): # write 类型没有 aimessages,所以 Answer 为空 @@ -230,15 +229,14 @@ class RedisWriteStore: "Answer": "", "starttime": data.get('starttime', '') }) - + # 根据时间范围过滤 filtered_items = filter_by_time_range(matched_items, minutes) # 排序并移除时间字段 - result_items = sort_and_limit_results(filtered_items, limit=None) - print(result_items) + result_items = sort_and_limit_results(filtered_items) elapsed_time = time.time() - start_time - print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " + logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items @@ -258,7 +256,7 @@ class RedisWriteStore: class RedisCountStore: """Redis Count 类型存储类,用于管理访问次数统计相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -278,7 +276,7 @@ class RedisCountStore: decode_responses=True, encoding='utf-8' ) - self.uudi = session_id + self.uuid = session_id def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: """ @@ -295,26 +293,26 @@ class RedisCountStore: session_id = str(uuid.uuid4()) key = generate_session_key(session_id, key_type="count") index_key = f'session:count:index:{end_user_id}' # 索引键 - + pipe = self.r.pipeline() pipe.hset(key, mapping={ - "id": self.uudi, + "id": self.uuid, "end_user_id": end_user_id, "count": int(count), "messages": serialize_messages(messages), "starttime": get_current_timestamp() }) pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 - + # 创建索引:end_user_id -> session_id 映射 pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) - + result = pipe.execute() - - print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") + + logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") return session_id - def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: + def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool: """ 通过 end_user_id 查询访问次数统计 @@ -327,7 +325,7 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) @@ -335,35 +333,40 @@ class RedisCountStore: self.r.delete(index_key) return False except Exception as type_error: - print(f"[get_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: return False - + # 直接获取数据 key = generate_session_key(session_id, key_type="count") data = self.r.hgetall(key) - + if not data: # 索引存在但数据不存在,清理索引 self.r.delete(index_key) return False - + count = data.get('count') messages_str = data.get('messages') - + if count is not None: - messages = deserialize_messages(messages_str) - return [int(count), messages] - + messages: list[dict] = deserialize_messages(messages_str) + return int(count), messages + return False except Exception as e: - print(f"[get_sessions_count] 查询失败: {e}") + logger.error(f"[get_sessions_count] 查询失败: {e}") return False - def update_sessions_count(self, end_user_id: str, new_count: int, - messages: Any) -> bool: + + def update_sessions_count( + self, + end_user_id: str, + new_count: int, + messages: Any + ) -> bool: """ 通过 end_user_id 修改访问次数统计(优化版:使用索引) @@ -378,39 +381,39 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) if key_type != 'string' and key_type != 'none': # 索引键类型错误,删除并返回 False - print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") self.r.delete(index_key) - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False except Exception as type_error: - print(f"[update_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False - + # 直接更新数据 key = generate_session_key(session_id, key_type="count") messages_str = serialize_messages(messages) - + pipe = self.r.pipeline() - pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'count', str(new_count)) pipe.hset(key, 'messages', messages_str) result = pipe.execute() - - print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + + logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") return True - + except Exception as e: - print(f"[update_sessions_count] 更新失败: {e}") + logger.debug(f"[update_sessions_count] 更新失败: {e}") return False def delete_all_count_sessions(self) -> int: @@ -428,7 +431,7 @@ class RedisCountStore: class RedisSessionStore: """Redis 会话存储类,用于管理会话数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -451,9 +454,9 @@ class RedisSessionStore: self.uudi = session_id # ==================== 写入操作 ==================== - - def save_session(self, userid: str, messages: str, aimessages: str, - apply_id: str, end_user_id: str) -> str: + + def save_session(self, userid: str, messages: str, aimessages: str, + apply_id: str, end_user_id: str) -> str: """ 写入一条会话数据,返回 session_id @@ -483,14 +486,14 @@ class RedisSessionStore: }) result = pipe.execute() - print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session] 保存会话失败: {e}") + logger.error(f"[save_session] 保存会话失败: {e}") raise e # ==================== 读取操作 ==================== - + def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """ 读取一条会话数据 @@ -520,8 +523,8 @@ class RedisSessionStore: sessions[sid] = self.get_session(sid) return sessions - def find_user_apply_group(self, sessionid: str, apply_id: str, - end_user_id: str) -> List[Dict[str, str]]: + def find_user_apply_group(self, sessionid: str, apply_id: str, + end_user_id: str) -> List[Dict[str, str]]: """ 根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条 @@ -535,10 +538,10 @@ class RedisSessionStore: """ import time start_time = time.time() - + keys = self.r.keys('session:*') if not keys: - print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -556,21 +559,21 @@ class RedisSessionStore: continue if (data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): + data.get('end_user_id') == end_user_id): # 支持模糊匹配或完全匹配 sessionid if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: matched_items.append(format_session_data(data, include_time=True)) - + # 排序、限制数量并移除时间字段 result_items = sort_and_limit_results(matched_items, limit=6) elapsed_time = time.time() - start_time - print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items # ==================== 更新操作 ==================== - + def update_session(self, session_id: str, field: str, value: Any) -> bool: """ 更新单个字段 @@ -591,7 +594,7 @@ class RedisSessionStore: return bool(results[0]) # ==================== 删除操作 ==================== - + def delete_session(self, session_id: str) -> int: """ 删除单条会话 @@ -632,7 +635,7 @@ class RedisSessionStore: keys = self.r.keys('session:*') if not keys: - print("[delete_duplicate_sessions] 没有会话数据") + logger.debug("[delete_duplicate_sessions] 没有会话数据") return 0 # 批量获取所有数据 @@ -678,7 +681,7 @@ class RedisSessionStore: deleted_count += len(batch) elapsed_time = time.time() - start_time - print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") + logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") return deleted_count diff --git a/api/app/core/memory/llm_tools/llm_client.py b/api/app/core/memory/llm_tools/llm_client.py index e26aba3e..49cd9434 100644 --- a/api/app/core/memory/llm_tools/llm_client.py +++ b/api/app/core/memory/llm_tools/llm_client.py @@ -56,7 +56,7 @@ class LLMClient(ABC): self.max_retries = self.config.max_retries self.timeout = self.config.timeout - logger.info( + logger.debug( f"初始化 LLM 客户端: provider={self.provider}, " f"model={self.model_name}, max_retries={self.max_retries}" ) diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b4efe61d..97aa5bb5 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -17,6 +17,7 @@ class Write_UserInput(BaseModel): end_user_id: str config_id: Optional[str] = None + class AgentMemory_Long_Term(ABC): """长期记忆配置常量""" STORAGE_NEO4J = "neo4j" @@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC): STRATEGY_CHUNK = "chunk" STRATEGY_TIME = "time" DEFAULT_SCOPE = 6 - TIME_SCOPE=5 -class AgentMemoryDataset(ABC): - PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余'] - NAME='用户' + TIME_SCOPE = 5 + +class AgentMemoryDataset(ABC): + PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余'] + NAME = '用户' diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 90474428..17c2f98c 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from app.core.agent.langchain_agent import LangChainAgent from app.core.logging_config import get_business_logger +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import WorkflowConfig @@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService from app.services.draft_run_service import AgentRunService +from app.services.memory_agent_service import get_end_user_connected_config from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService from app.services.workflow_service import WorkflowService -from app.schemas import FileType logger = get_business_logger() @@ -43,18 +44,17 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, - user_id: Optional[str] = None, + files: list[FileInput], + user_id: str, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> Dict[str, Any]: """聊天(非流式)""" start_time = time.time() - config_id = None # 应用 features 配置 features_config: dict = config.features or {} @@ -93,7 +93,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, + user_id) tools.extend(kb_tools) memory_flag = False if memory: @@ -168,11 +169,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -229,6 +225,21 @@ class AppChatService: # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": result["content"]} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -264,20 +275,19 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, + files: list[FileInput], user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """聊天(流式)""" try: start_time = time.time() - config_id = None message_id = uuid.uuid4() # 应用 features 配置 @@ -319,7 +329,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config( + config.knowledge_retrieval, user_id) tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -411,11 +422,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): @@ -459,7 +465,7 @@ class AppChatService: # 保存消息 human_meta = { - "files":[], + "files": [], "history_files": {} } assistant_meta = { @@ -484,6 +490,22 @@ class AppChatService: if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url + + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": full_content} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -618,7 +640,6 @@ class AppChatService: # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) - # 3. 流式执行任务 async for event in orchestrator.execute_stream( message=message, diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index e188872f..aef54847 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context -from app.models import AgentConfig, ModelConfig, ModelType +from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo @@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_service import ModelApiKeyService from app.services.multimodal_service import MultimodalService from app.services.tool_service import ToolService -from app.schemas import FileType logger = get_business_logger() @@ -657,11 +656,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -911,11 +905,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 3ee238e2..5c838fc0 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -243,27 +243,6 @@ class MemoryPerceptualService: memory_config: MemoryConfig, file: FileInput ): - memories = self.repository.get_by_url(file.url) - if memories: - business_logger.info(f"Perceptual memory already exists: {file.url}") - if end_user_id not in [memory.end_user_id for memory in memories]: - business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") - memory_cache = memories[0] - memory = self.repository.create_perceptual_memory( - end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType(memory_cache.perceptual_type), - file_path=memory_cache.file_path, - file_name=memory_cache.file_name, - file_ext=memory_cache.file_ext, - summary=memory_cache.summary, - meta_data=memory_cache.meta_data - ) - self.db.commit() - return memory - else: - for memory in memories: - if memory.end_user_id == uuid.UUID(end_user_id): - return memory llm, model_config = self._get_mutlimodal_client(file.type, memory_config) multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index b98674ba..c9266667 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -69,7 +69,8 @@ class ModelConfigService: return items @staticmethod - def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def get_model_by_name(db: Session, name: str, provider: str | None = None, + tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据名称获取模型配置""" model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) if not model: @@ -77,21 +78,22 @@ class ModelConfigService: return model @staticmethod - def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]: + def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ + ModelConfig]: """按名称模糊匹配获取模型配置列表""" return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) @staticmethod async def validate_model_config( - db: Session, - *, - model_name: str, - provider: str, - api_key: str, - api_base: Optional[str] = None, - model_type: str = "llm", - test_message: str = "Hello", - is_omni: bool = False + db: Session, + *, + model_name: str, + provider: str, + api_key: str, + api_base: Optional[str] = None, + model_type: str = "llm", + test_message: str = "Hello", + is_omni: bool = False ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -158,13 +160,13 @@ class ModelConfigService: # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] - + # 火山引擎使用 embed_batch,其他使用 embed_documents if provider.lower() == "volcano": vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) else: vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) - + elapsed_time = time.time() - start_time return { @@ -200,11 +202,11 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "image": # 图片生成模型验证 from app.core.models.generation import RedBearImageGenerator - + generator = RedBearImageGenerator(model_config) result = await generator.agenerate( prompt="a cute panda", @@ -212,7 +214,7 @@ class ModelConfigService: ) elapsed_time = time.time() - start_time logger.info(f"成功生成图片,结果: {result}") - + return { "valid": True, "message": "图片生成模型配置验证成功", @@ -224,21 +226,21 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "video": # 视频生成模型验证 from app.core.models.generation import RedBearVideoGenerator - + generator = RedBearVideoGenerator(model_config) result = await generator.agenerate( prompt="a cute panda playing in bamboo forest", duration=5 ) elapsed_time = time.time() - start_time - + # 视频生成是异步任务,返回任务ID task_id = result.get("task_id") if isinstance(result, dict) else None - + return { "valid": True, "message": "视频生成模型配置验证成功", @@ -265,7 +267,6 @@ class ModelConfigService: # 提取详细的错误信息 error_message = str(e) error_type = type(e).__name__ - print("=========error_message:",error_message.lower()) # 特殊处理常见的错误类型 if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): # 区域/国家限制(适用于所有提供商) @@ -354,14 +355,16 @@ class ModelConfigService: return model @staticmethod - def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, + tenant_id: uuid.UUID | None = None) -> ModelConfig: """更新模型配置""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) @@ -370,25 +373,27 @@ class ModelConfigService: return model @staticmethod - async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """创建组合模型""" - if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + # 检查 API Key 关联的模型配置类型 for model_config in api_key.model_configs: # chat 和 llm 类型可以兼容 compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = model_data.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", @@ -399,7 +404,7 @@ class ModelConfigService: # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", # BizCode.INVALID_PARAMETER # ) - + # 创建组合模型 model_config_data = { "tenant_id": tenant_id, @@ -418,49 +423,51 @@ class ModelConfigService: model = ModelConfigRepository.create(db, model_config_data) db.flush() - + # 关联 API Keys for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: model.api_keys.append(api_key) - + db.commit() db.refresh(model) return model @staticmethod - async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """更新组合模型""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + if not existing_model.is_composite: raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + for model_config in api_key.model_configs: compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = existing_model.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", BizCode.INVALID_PARAMETER ) - + # 更新基本信息 existing_model.name = model_data.name # existing_model.type = model_data.type @@ -471,14 +478,14 @@ class ModelConfigService: existing_model.is_public = model_data.is_public if "load_balance_strategy" in model_data.model_fields_set: existing_model.load_balance_strategy = model_data.load_balance_strategy - + # 更新 API Keys 关联 existing_model.api_keys.clear() for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: existing_model.api_keys.append(api_key) - + db.commit() db.refresh(existing_model) return existing_model @@ -532,7 +539,7 @@ class ModelApiKeyService: """根据provider为多个ModelConfig创建API Key""" created_keys = [] failed_models = [] # 记录验证失败的模型 - + for model_config_id in data.model_config_ids: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: @@ -540,10 +547,10 @@ class ModelApiKeyService: data.is_omni = model_config.is_omni data.capability = model_config.capability - + # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name - + # 检查是否存在API Key(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -553,7 +560,7 @@ class ModelApiKeyService: ModelApiKey.model_name == model_name, ModelConfig.tenant_id == model_config.tenant_id ).first() - + if existing_key: # 如果已存在,重新激活并更新 if existing_key.is_active: @@ -566,14 +573,14 @@ class ModelApiKeyService: existing_key.model_name = model_name existing_key.capability = data.capability existing_key.is_omni = data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + created_keys.append(existing_key) continue - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -589,7 +596,7 @@ class ModelApiKeyService: # 记录验证失败的模型,但不抛出异常 failed_models.append(model_name) continue - + # 创建API Key api_key_data = ModelApiKeyCreate( model_config_ids=[model_config_id], @@ -606,12 +613,12 @@ class ModelApiKeyService: ) api_key_obj = ModelApiKeyRepository.create(db, api_key_data) created_keys.append(api_key_obj) - + if created_keys: db.commit() for key in created_keys: db.refresh(key) - + return created_keys, failed_models @staticmethod @@ -626,7 +633,7 @@ class ModelApiKeyService: api_key_data.is_omni = model_config.is_omni if api_key_data.capability is None: api_key_data.capability = model_config.capability - + # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -650,15 +657,15 @@ class ModelApiKeyService: existing_key.model_name = api_key_data.model_name existing_key.capability = api_key_data.capability existing_key.is_omni = api_key_data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + db.commit() db.refresh(existing_key) return existing_key - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -691,7 +698,7 @@ class ModelApiKeyService: # 获取关联的模型配置以获取模型类型 if existing_api_key.model_configs: model_config = existing_api_key.model_configs[0] - + validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name or existing_api_key.model_name, @@ -729,15 +736,15 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: return None - + api_keys = [key for key in model_config.api_keys if key.is_active] if not api_keys: return None - + # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) - + # 否则返回第一个 return api_keys[0] @@ -760,20 +767,19 @@ class ModelApiKeyService: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - class ModelBaseService: """基础模型服务""" @staticmethod def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: models = ModelBaseRepository.get_list(db, query) - + provider_groups = {} for m in models: model_dict = model_schema.ModelBase.model_validate(m).model_dump() if tenant_id: model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) - + provider = m.provider if provider not in provider_groups: provider_groups[provider] = { @@ -781,7 +787,7 @@ class ModelBaseService: "models": [] } provider_groups[provider]["models"].append(model_dict) - + return list(provider_groups.values()) @staticmethod @@ -823,10 +829,10 @@ class ModelBaseService: model_base = ModelBaseRepository.get_by_id(db, model_base_id) if not model_base: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) - + if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) - + model_config_data = { "model_id": model_base_id, "tenant_id": tenant_id, diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 0d659832..c74604a5 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -1,26 +1,24 @@ """基于分享链接的聊天服务""" -import uuid -import time import asyncio +import json +import time +import uuid from typing import Optional, Dict, Any, AsyncGenerator + +from deprecated import deprecated from sqlalchemy.orm import Session -from app.repositories.model_repository import ModelApiKeyRepository -from app.services.memory_konwledges_server import write_rag +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException, ResourceNotFoundException +from app.core.logging_config import get_business_logger +from app.models import MultiAgentConfig from app.models import ReleaseShare, AppRelease, Conversation +from app.repositories import knowledge_repository from app.services.conversation_service import ConversationService from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService -from app.services.release_share_service import ReleaseShareService -from app.core.exceptions import BusinessException, ResourceNotFoundException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger from app.services.multi_agent_service import MultiAgentService -from app.models import MultiAgentConfig -from app.repositories import knowledge_repository -import json -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task +from app.services.release_share_service import ReleaseShareService logger = get_business_logger() @@ -118,6 +116,7 @@ class SharedChatService: return conversation + @deprecated("Use the chat method under app_chat_service instead.") async def chat( self, share_token: str, @@ -136,10 +135,7 @@ class SharedChatService: config_id = actual_config_id from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool - from app.services.model_parameter_merger import ModelParameterMerger from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey start_time = time.time() actual_config_id = None @@ -273,11 +269,6 @@ class SharedChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag ) # 保存消息 @@ -324,6 +315,7 @@ class SharedChatService: "elapsed_time": elapsed_time } + @deprecated("Use the chat method under app_chat_service instead.") async def chat_stream( self, share_token: str, @@ -341,8 +333,6 @@ class SharedChatService: from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey import json start_time = time.time() @@ -486,11 +476,6 @@ class SharedChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag ): if isinstance(chunk, int): total_tokens = chunk @@ -585,6 +570,7 @@ class SharedChatService: return conversations, total + @deprecated("Use the chat method under app_chat_service instead.") async def multi_agent_chat( self, share_token: str, @@ -680,6 +666,7 @@ class SharedChatService: "elapsed_time": elapsed_time } + @deprecated("Use the chat method under app_chat_service instead.") async def multi_agent_chat_stream( self, share_token: str, From 5703fc0cb4c5b2b9957d14091afccc54ea5a0b06 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 13:45:17 +0800 Subject: [PATCH 084/120] [fix] Set the page for the nodes to be forgotten --- .../controllers/memory_forget_controller.py | 95 +++++++++++ api/app/schemas/memory_storage_schema.py | 17 +- api/app/services/memory_forget_service.py | 156 +++++++++++++++--- 3 files changed, 240 insertions(+), 28 deletions(-) diff --git a/api/app/controllers/memory_forget_controller.py b/api/app/controllers/memory_forget_controller.py index 2b5ef72f..51ce92b3 100644 --- a/api/app/controllers/memory_forget_controller.py +++ b/api/app/controllers/memory_forget_controller.py @@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import ( ForgettingCurveRequest, ForgettingCurveResponse, ForgettingCurvePoint, + PendingNodesResponse, ) from app.schemas.response_schema import ApiResponse from app.services.memory_forget_service import MemoryForgetService @@ -308,6 +309,100 @@ async def get_forgetting_stats( return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e)) +@router.get("/pending-nodes", response_model=ApiResponse) +async def get_pending_nodes( + end_user_id: str, + page: int = 1, + pagesize: int = 10, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + 获取待遗忘节点列表(独立分页接口) + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。 + 此接口独立分页,与 /stats 接口分离。 + + Args: + end_user_id: 组ID(即 end_user_id,必填) + page: 页码(从1开始,默认1) + pagesize: 每页数量(默认10) + current_user: 当前用户 + db: 数据库会话 + + Returns: + ApiResponse: 包含待遗忘节点列表和分页信息的响应 + + Examples: + - 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10 + - 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20 + + Notes: + - page 从1开始,pagesize 必须大于0 + - 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}} + """ + workspace_id = current_user.current_workspace_id + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + # 验证 end_user_id 必填 + if not end_user_id: + api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id") + return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required") + + # 通过 end_user_id 获取关联的 config_id + try: + from app.services.memory_agent_service import get_end_user_connected_config + + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + config_id = resolve_config_id(config_id, db) + + if config_id is None: + api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置") + return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None") + + api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}") + except ValueError as e: + api_logger.warning(f"获取终端用户配置失败: {str(e)}") + return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") + except Exception as e: + api_logger.error(f"获取终端用户配置时发生错误: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e)) + + # 验证分页参数 + if page < 1: + return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1") + if pagesize < 1: + return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1") + + api_logger.info( + f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: " + f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}" + ) + + try: + # 调用服务层获取待遗忘节点列表 + result = await forget_service.get_pending_nodes( + db=db, + end_user_id=end_user_id, + config_id=config_id, + page=page, + pagesize=pagesize + ) + + # 构建响应 + response_data = PendingNodesResponse(**result) + + return success(data=response_data.model_dump(), msg="查询成功") + + except Exception as e: + api_logger.error(f"获取待遗忘节点列表失败: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e)) + + @router.post("/forgetting_curve", response_model=ApiResponse) async def get_forgetting_curve( request: ForgettingCurveRequest, diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 711b6de9..bfcf6337 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel): last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)") +class PageInfo(BaseModel): + """分页信息模型""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + page: int = Field(..., description="当前页码(从1开始)") + pagesize: int = Field(..., description="每页数量") + total: int = Field(..., description="总记录数") + hasnext: bool = Field(..., description="是否有下一页") + + +class PendingNodesResponse(BaseModel): + """待遗忘节点列表响应模型(独立分页接口)""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表") + page: PageInfo = Field(..., description="分页信息") + + class ForgettingStatsResponse(BaseModel): """遗忘引擎统计信息响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel): node_distribution: Dict[str, int] = Field(..., description="节点类型分布") recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") - pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") timestamp: int = Field(..., description="统计时间(时间戳)") diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 5122ae02..2d91f025 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -203,29 +203,36 @@ class MemoryForgetService: connector: Neo4jConnector, end_user_id: str, forgetting_threshold: float, - min_days_since_access: int - ) -> list[Dict[str, Any]]: + min_days_since_access: int, + page: Optional[int] = None, + pagesize: Optional[int] = None + ) -> Dict[str, Any]: """ 获取待遗忘节点列表 - - 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数) - + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。 + Args: connector: Neo4j 连接器 end_user_id: 组ID forgetting_threshold: 遗忘阈值 min_days_since_access: 最小未访问天数 - + page: 页码(可选,从1开始) + pagesize: 每页数量(可选) + Returns: - list: 待遗忘节点列表 + dict: 包含待遗忘节点列表和分页信息的字典 + - items: 待遗忘节点列表 + - page: 分页信息(分页时) """ from datetime import timedelta - + # 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区) min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access) min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') - - query = """ + + # 基础查询(用于获取总数) + count_query = """ MATCH (n) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) AND n.end_user_id = $end_user_id @@ -233,10 +240,22 @@ class MemoryForgetService: AND n.activation_value < $threshold AND n.last_access_time IS NOT NULL AND datetime(n.last_access_time) < datetime($min_access_time_str) - RETURN + RETURN count(n) as total + """ + + # 数据查询 + data_query = """ + MATCH (n) + WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) + AND n.end_user_id = $end_user_id + AND n.activation_value IS NOT NULL + AND n.activation_value < $threshold + AND n.last_access_time IS NOT NULL + AND datetime(n.last_access_time) < datetime($min_access_time_str) + RETURN elementId(n) as node_id, labels(n)[0] as node_type, - CASE + CASE WHEN n:Statement THEN n.statement WHEN n:ExtractedEntity THEN n.name WHEN n:MemorySummary THEN n.content @@ -246,15 +265,31 @@ class MemoryForgetService: n.last_access_time as last_access_time ORDER BY n.activation_value ASC """ - + + # 如果启用分页,添加 SKIP 和 LIMIT + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + data_query += " SKIP $skip LIMIT $limit" + params = { 'end_user_id': end_user_id, 'threshold': forgetting_threshold, 'min_access_time_str': min_access_time_str } - - results = await connector.execute_query(query, **params) - + + # 获取总数(分页时需要) + total = 0 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + count_results = await connector.execute_query(count_query, **params) + if count_results: + total = count_results[0]['total'] + + # 添加分页参数 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + params['skip'] = (page - 1) * pagesize + params['limit'] = pagesize + + results = await connector.execute_query(data_query, **params) + pending_nodes = [] for result in results: # 将节点类型标签转换为小写 @@ -263,7 +298,7 @@ class MemoryForgetService: node_type_label = 'entity' elif node_type_label == 'memorysummary': node_type_label = 'summary' - + # 将 Neo4j DateTime 对象转换为时间戳(毫秒) last_access_time = result['last_access_time'] last_access_dt = convert_neo4j_datetime_to_python(last_access_time) @@ -274,7 +309,7 @@ class MemoryForgetService: last_access_timestamp = int(last_access_dt.timestamp() * 1000) else: last_access_timestamp = 0 - + pending_nodes.append({ 'node_id': str(result['node_id']), 'node_type': node_type_label, @@ -282,8 +317,20 @@ class MemoryForgetService: 'activation_value': result['activation_value'], 'last_access_time': last_access_timestamp }) - - return pending_nodes + + # 构建返回结果 + result: Dict[str, Any] = {'items': pending_nodes} + + # 如果启用分页,添加分页信息 + if page is not None and pagesize is not None and page > 0 and pagesize > 0: + result['page'] = { + 'page': page, + 'pagesize': pagesize, + 'total': total, + 'hasnext': (page * pagesize) < total + } + + return result async def trigger_forgetting_cycle( self, @@ -656,24 +703,79 @@ class MemoryForgetService: except Exception as e: api_logger.error(f"获取待遗忘节点失败: {str(e)}") # 失败时返回空列表,不影响主流程 - - # 构建统计信息 + + # 构建统计信息(不包含 pending_nodes,已分离到独立接口) stats = { 'activation_metrics': activation_metrics, 'node_distribution': node_distribution, 'recent_trends': recent_trends, - 'pending_nodes': pending_nodes, 'timestamp': int(datetime.now().timestamp() * 1000) } - + api_logger.info( f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, " f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, " - f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}" + f"trend_days={len(recent_trends)}" ) - + return stats - + + async def get_pending_nodes( + self, + db: Session, + end_user_id: str, + config_id: Optional[UUID] = None, + page: int = 1, + pagesize: int = 10 + ) -> Dict[str, Any]: + """ + 获取待遗忘节点列表(独立分页接口) + + 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。 + + Args: + db: 数据库会话 + end_user_id: 组ID(必填) + config_id: 配置ID(可选,用于获取遗忘阈值) + page: 页码(从1开始,默认1) + pagesize: 每页数量(默认10) + + Returns: + dict: 包含待遗忘节点列表和分页信息的字典 + - items: 待遗忘节点列表 + - page: 分页信息 + """ + # 获取遗忘引擎组件 + _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id) + + connector = forgetting_scheduler.connector + forgetting_threshold = config['forgetting_threshold'] + + # 验证 min_days_since_access 配置值 + min_days = config.get('min_days_since_access') + if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0: + api_logger.warning( + f"min_days_since_access 配置无效: {min_days}, 使用默认值 7" + ) + min_days = 7 + + # 调用内部方法获取分页数据 + pending_nodes_result = await self._get_pending_forgetting_nodes( + connector=connector, + end_user_id=end_user_id, + forgetting_threshold=forgetting_threshold, + min_days_since_access=int(min_days), + page=page, + pagesize=pagesize + ) + + api_logger.info( + f"成功获取待遗忘节点列表: end_user_id={end_user_id}, " + f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}" + ) + + return pending_nodes_result + async def get_forgetting_curve( self, db: Session, From c89eccf8fe25c039ac93011fc15ca312b8ef7e0c Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 30 Mar 2026 14:55:04 +0800 Subject: [PATCH 085/120] fix(public_share_chat): History conversation message returns audio status --- api/app/controllers/public_share_controller.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index f5284b46..134379fb 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService from app.services.workflow_service import WorkflowService +from app.models.file_metadata_model import FileMetadata from app.utils.app_config_utils import workflow_config_4_app_release, \ agent_config_4_app_release, multi_agent_config_4_app_release @@ -259,8 +260,19 @@ def get_conversation( conv_service = ConversationService(db) messages = conv_service.get_messages(conversation_id) - # 构建响应 - conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump() + # 为 assistant 消息查询 audio_url 状态 + for m in messages: + if m.role == "assistant" and m.meta_data: + audio_url = m.meta_data.get("audio_url") + if audio_url: + try: + file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1]) + file_meta = db.get(FileMetadata, file_id) + m.meta_data["audio_status"] = file_meta.status if file_meta else "unknown" + except (ValueError, IndexError): + m.meta_data["audio_status"] = "unknown" + + conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json") conv_dict["messages"] = [ conversation_schema.Message.model_validate(m) for m in messages ] From e59a215078f9b38a8813021bbf67058b10abb808 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 15:03:58 +0800 Subject: [PATCH 086/120] fix(web): app source key change --- web/src/views/ApplicationManagement/index.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/views/ApplicationManagement/index.tsx b/web/src/views/ApplicationManagement/index.tsx index 4d49635c..3b444a3d 100644 --- a/web/src/views/ApplicationManagement/index.tsx +++ b/web/src/views/ApplicationManagement/index.tsx @@ -216,7 +216,7 @@ const ApplicationManagement: React.FC = () => { 'rb:text-[#155EEF]': key === 'type', })}> {key === 'source' && item.is_shared - ? t('application.shared') + ? item.source_workspace_name : key === 'source' && !item.is_shared ? t('application.configuration') : key === 'created_at' From 8285250096c6bcda083522a8cfedf018e714c0fc Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 30 Mar 2026 15:06:35 +0800 Subject: [PATCH 087/120] fix(public_share_chat): History conversation message returns audio status --- .../controllers/public_share_controller.py | 40 ++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 134379fb..2b224e28 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -260,17 +260,39 @@ def get_conversation( conv_service = ConversationService(db) messages = conv_service.get_messages(conversation_id) - # 为 assistant 消息查询 audio_url 状态 - for m in messages: + file_ids = [] + message_file_id_map = {} + + # 第一次遍历:解析 audio_url,收集所有有效的 file_id + for idx, m in enumerate(messages): if m.role == "assistant" and m.meta_data: audio_url = m.meta_data.get("audio_url") - if audio_url: - try: - file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1]) - file_meta = db.get(FileMetadata, file_id) - m.meta_data["audio_status"] = file_meta.status if file_meta else "unknown" - except (ValueError, IndexError): - m.meta_data["audio_status"] = "unknown" + if not audio_url: + continue + try: + file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1]) + except (ValueError, IndexError): + # audio_url 无法解析为 UUID,标记为 unknown + m.meta_data["audio_status"] = "unknown" + continue + + file_ids.append(file_id) + message_file_id_map[idx] = file_id + + # 批量查询所有相关的 FileMetadata + file_status_map = {} + if file_ids: + file_metas = ( + db.query(FileMetadata) + .filter(FileMetadata.id.in_(set(file_ids))) + .all() + ) + file_status_map = {fm.id: fm.status for fm in file_metas} + + # 第二次遍历:将查询结果映射回消息 + for idx, file_id in message_file_id_map.items(): + m = messages[idx] + m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown") conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json") conv_dict["messages"] = [ From dae7431075a2b369bb520de8d30813f8854c9903 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 15:39:53 +0800 Subject: [PATCH 088/120] [fix] Refusing the user, I went to "other_name" --- .../extraction_orchestrator.py | 45 +++++++++++++------ .../prompt/prompts/extract_triplet.jinja2 | 6 +++ api/app/services/user_memory_service.py | 12 +++++ 3 files changed, 50 insertions(+), 13 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index f6a143cd..b20112a2 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1405,7 +1405,8 @@ class ExtractionOrchestrator: logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}") else: first_alias = current_aliases[0].strip() if current_aliases else "" - if first_alias: + # 确保 first_alias 不是占位名称 + if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES: db.add(EndUserInfo( end_user_id=end_user_uuid, other_name=first_alias, @@ -1421,29 +1422,33 @@ class ExtractionOrchestrator: + # 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中 + USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'} + def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: """从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序) - 这个方法直接返回 LLM 提取的别名列表,不做任何修改。 + 这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。 第一个别名将被用作 other_name。 Args: entity_nodes: 实体节点列表 Returns: - 别名列表(保持 LLM 提取的原始顺序) + 别名列表(保持 LLM 提取的原始顺序,已过滤占位名称) """ - USER_NAMES = {'用户', '我', 'User', 'I'} for entity in entity_nodes: - if getattr(entity, 'name', '').strip() in USER_NAMES: + if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES: aliases = getattr(entity, 'aliases', []) or [] - logger.debug(f"提取到用户别名(原始顺序): {aliases}") - return aliases + # 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name + filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] + logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}") + return filtered return [] async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]: - """从 Neo4j 查询用户实体的完整 aliases 列表""" + """从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)""" cypher = """ MATCH (e:ExtractedEntity) WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I'] @@ -1457,7 +1462,10 @@ class ExtractionOrchestrator: aliases = result[0].get('aliases') or [] if not aliases: logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}") - return aliases + return [] + # 过滤掉占位名称,防止历史脏数据传播 + filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES] + return filtered def _resolve_other_name( self, @@ -1469,14 +1477,25 @@ class ExtractionOrchestrator: 决定 other_name 是否需要更新,返回新值;无需更新返回 None。 决策规则: - - 为空 → 用本次对话第一个别名 + - 为空或为占位名称 → 用本次对话第一个别名 - 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除) - 否则 → 保持不变(返回 None) + + 注意:返回值不允许是占位名称("用户"、"我"、"User"、"I") """ - if not current or not current.strip(): - return current_aliases[0].strip() if current_aliases else None + # 当前值为空或为占位名称时,需要更新 + if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES: + candidate = current_aliases[0].strip() if current_aliases else None + # 确保候选值不是占位名称 + if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + return None + return candidate if current not in neo4j_aliases: - return neo4j_aliases[0].strip() if neo4j_aliases else None + candidate = neo4j_aliases[0].strip() if neo4j_aliases else None + # 确保候选值不是占位名称 + if candidate and candidate in self.USER_PLACEHOLDER_NAMES: + return None + return candidate return None diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index f9f2f45c..6605532d 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement. {% if language == "zh" %} - 用户实体的 name 字段:使用 "用户" 或 "我" - 用户的真实姓名:放入 aliases + - **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等** - 示例: * "我叫李明" → name="用户", aliases=["李明"] + * ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases) + * ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases) {% else %} - User entity name field: use "User" or "I" - User's real name: put in aliases + - **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.** - Examples: * "I'm John" → name="User", aliases=["John"] + * ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases) + * ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases) {% endif %} diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 942e01a0..c6743ff2 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -472,6 +472,18 @@ class UserMemoryService: # 定义允许更新的字段白名单 allowed_fields = {'other_name', 'aliases', 'meta_data'} + # 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中 + _user_placeholder_names = {'用户', '我', 'User', 'I'} + + # 过滤 other_name:不允许设置为占位名称 + if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names: + logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name") + del update_data['other_name'] + + # 过滤 aliases:移除占位名称 + if 'aliases' in update_data and update_data['aliases']: + update_data['aliases'] = [a for a in update_data['aliases'] if a.strip() not in _user_placeholder_names] + # 检查是否更新了 aliases 字段 aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases From 64a73c41d639a2e6b06f35a69aedf14b322b82d5 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 15:49:58 +0800 Subject: [PATCH 089/120] fix(web): chat history audio add status --- web/src/components/Chat/ChatContent.tsx | 24 ++++--- web/src/views/ApplicationConfig/Logs.tsx | 2 +- web/src/views/Conversation/index.tsx | 86 +++++++++++++++--------- 3 files changed, 69 insertions(+), 43 deletions(-) diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index ddb25838..0276916f 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -37,11 +37,11 @@ const ChatContent: FC = ({ const prevDataLengthRef = useRef(data.length); const isScrolledToBottomRef = useRef(true); const audioRef = useRef(null) - const [playingIndex, setPlayingIndex] = useState(null) + const [playingIndex, setPlayingIndex] = useState(null) - const handlePlay = (index: number, audio_url: string, audio_status?: string) => { - if (audio_status !== 'completed' && !audio_status) return - if (playingIndex === index) { + const handlePlay = (audio_url: string, audio_status?: string) => { + if (audio_status !== 'completed' && typeof audio_status === 'string') return + if (playingIndex === audio_url) { audioRef.current?.pause() setPlayingIndex(null) return @@ -52,7 +52,7 @@ const ChatContent: FC = ({ const audio = new Audio(audio_url) audioRef.current = audio audio.play() - setPlayingIndex(index) + setPlayingIndex(audio_url) audio.onended = () => setPlayingIndex(null) } @@ -79,12 +79,16 @@ const ChatContent: FC = ({ } }; }, []); - + // Auto-scroll to bottom when data changes to show latest messages // When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom // When data array length changes, auto-scroll to bottom // If already scrolled to bottom, will auto-scroll to bottom useEffect(() => { + if (playingIndex && !data.some(item => item.meta_data?.audio_url === playingIndex)) { + audioRef.current?.pause() + setPlayingIndex(null) + } setTimeout(() => { if (scrollContainerRef.current) { // Auto-scroll if data length changed OR user is currently at bottom @@ -204,16 +208,16 @@ const ChatContent: FC = ({ {item.meta_data?.audio_url && <> - {playingIndex !== index && item.meta_data?.audio_status === 'pending' + {playingIndex !== item.meta_data?.audio_url && item.meta_data?.audio_status === 'pending' ? - : playingIndex !== index + : playingIndex !== item.meta_data?.audio_url ? handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> + })} onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> :
handlePlay(index, item.meta_data?.audio_url!, item.meta_data?.audio_status)} + onClick={() => handlePlay(item.meta_data?.audio_url!, item.meta_data?.audio_status)} /> } diff --git a/web/src/views/ApplicationConfig/Logs.tsx b/web/src/views/ApplicationConfig/Logs.tsx index 49a5bbd6..cf56059c 100644 --- a/web/src/views/ApplicationConfig/Logs.tsx +++ b/web/src/views/ApplicationConfig/Logs.tsx @@ -34,7 +34,7 @@ const Statistics: FC = () => { className: 'rb:text-[#212332]' }, { - title: t('application.createTime'), + title: t('application.created_at'), dataIndex: 'created_at', key: 'created_at', render: (createdAt: string) => formatDateTime(createdAt, 'YYYY-MM-DD HH:mm:ss'), diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index 80394317..d4d25070 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -64,6 +64,13 @@ const Conversation: FC = () => { const [config, setConfig] = useState>({}) const [audioStatusMap, setAudioStatusMap] = useState>({}) + useEffect(() => { + return () => { + audioPollingRef.current.forEach((timer) => clearInterval(timer)) + audioPollingRef.current.clear() + } + }, []) + useEffect(() => { const shareToken = localStorage.getItem(`shareToken_${token}`) setShareToken(shareToken) @@ -144,13 +151,29 @@ const Conversation: FC = () => { } useEffect(() => { - audioPollingRef.current.forEach((timer) => clearInterval(timer)) - audioPollingRef.current.clear() if (conversation_id) { getConversationDetail(token as string, conversation_id) .then(res => { const response = res as { messages: ChatItem[] } - setChatList(response?.messages || []) + const messages = response?.messages || [] + const historyAudioUrls = new Set(messages.map(m => m.meta_data?.audio_url).filter(Boolean)) + audioPollingRef.current.forEach((timer, key) => { + if (!historyAudioUrls.has(key)) { + clearInterval(timer) + audioPollingRef.current.delete(key) + } + }) + messages.forEach(msg => { + if (msg.role === 'assistant' && msg.meta_data?.audio_url && msg.meta_data?.audio_status === 'pending') { + startAudioPolling(msg.meta_data.audio_url, msg.meta_data.audio_url) + } + }) + setChatList(messages.map(msg => { + if (msg.role === 'assistant' && msg.meta_data?.audio_url && audioPollingRef.current.has(msg.meta_data.audio_url)) { + return { ...msg, meta_data: { ...msg.meta_data, audio_status: 'pending' } } + } + return msg + })) }) } else { if (features?.opening_statement?.statement) { @@ -228,6 +251,28 @@ const Conversation: FC = () => { })) }, [audioStatusMap, chatList.length]) + const startAudioPolling = (audioUrl: string, idToPoll: string) => { + if (audioPollingRef.current.has(idToPoll)) return + const fileId = audioUrl.split('/').pop() + if (!fileId) return + const timer = setInterval(() => { + getFileStatusById(fileId) + .then(res => { + const { status } = res as { status: string } + if (status && status !== 'pending') { + setAudioStatusMap(prev => ({ ...prev, [idToPoll]: status })) + clearInterval(audioPollingRef.current.get(idToPoll)) + audioPollingRef.current.delete(idToPoll) + } + }) + .catch(() => { + clearInterval(audioPollingRef.current.get(idToPoll)) + audioPollingRef.current.delete(idToPoll) + }) + }, 2000) + audioPollingRef.current.set(idToPoll, timer) + } + /** Send message and handle streaming response */ const handleSend = (msg?: string) => { if (!token || !shareToken) return @@ -287,35 +332,8 @@ const Conversation: FC = () => { const { file_id } = item.data as { file_id?: string } const idToPoll = file_id || audio_url || '' const fileId = audio_url.split('/').pop() - if (fileId && idToPoll && !audioPollingRef.current.has(idToPoll)) { - - const timer = setInterval(() => { - getFileStatusById(fileId) - .then(res => { - const { status } = res as { status: string } - if (status && status !== 'pending') { - setAudioStatusMap(prev => ({ - ...prev, - [idToPoll]: status - })) - clearInterval(audioPollingRef.current.get(idToPoll)) - audioPollingRef.current.delete(idToPoll) - getHistory(true) - if (currentConversationId && currentConversationId !== conversation_id) { - setConversationId(currentConversationId) - } - } - }) - .catch(() => { - clearInterval(audioPollingRef.current.get(idToPoll)) - audioPollingRef.current.delete(idToPoll) - getHistory(true) - if (currentConversationId && currentConversationId !== conversation_id) { - setConversationId(currentConversationId) - } - }) - }, 2000) - audioPollingRef.current.set(idToPoll, timer) + if (fileId && idToPoll) { + startAudioPolling(audio_url, idToPoll) } } else { getHistory(true) @@ -327,6 +345,10 @@ const Conversation: FC = () => { updateAssistantMessage(content, audio_url, undefined, citations) } setLoading(false) + getHistory(true) + if (currentConversationId && currentConversationId !== conversation_id) { + setConversationId(currentConversationId) + } break } }) From c0cd2373c0be88a53c5544216e0c1601a8e36b36 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 15:51:30 +0800 Subject: [PATCH 090/120] [fix] Added type checking with isinstance(a, str) and filtering out empty strings with a.strip() --- api/app/services/user_memory_service.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index c6743ff2..ab51d922 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -480,9 +480,12 @@ class UserMemoryService: logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name") del update_data['other_name'] - # 过滤 aliases:移除占位名称 + # 过滤 aliases:移除占位名称和非字符串值 if 'aliases' in update_data and update_data['aliases']: - update_data['aliases'] = [a for a in update_data['aliases'] if a.strip() not in _user_placeholder_names] + update_data['aliases'] = [ + a for a in update_data['aliases'] + if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names + ] # 检查是否更新了 aliases 字段 aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases From e9ad13504ae4fbcc78ed3d81a39b0f660468653f Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 30 Mar 2026 16:00:49 +0800 Subject: [PATCH 091/120] fix(memory,task): add Redis fair lock for ordered memory writes --- .../core/memory/llm_tools/openai_client.py | 2 +- api/app/tasks.py | 40 ++++-- api/app/utils/redis_lock.py | 133 +++++++++++++++--- 3 files changed, 145 insertions(+), 30 deletions(-) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 43c2b445..c70fef5f 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -65,7 +65,7 @@ class OpenAIClient(LLMClient): type=type_ ) - logger.info(f"OpenAI 客户端初始化完成: type={type_}") + logger.debug(f"OpenAI 客户端初始化完成: type={type_}") async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index d5f09a29..0e909fcc 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,5 +1,4 @@ import asyncio -import hashlib import os import re import shutil @@ -38,12 +37,10 @@ from app.db import get_db, get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema -from app.schemas.model_schema import ModelInfo from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_forget_service import MemoryForgetService -from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id -from app.utils.redis_lock import RedisLock +from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) @@ -1148,8 +1145,28 @@ def write_message_task( logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result + redis_client = get_sync_redis_client() + lock = None + if redis_client is not None: + lock = RedisFairLock( + key=f"memory_write:{end_user_id}", + redis_client=redis_client, + expire=120, + timeout=300, + auto_renewal=True, + ) + if not lock.acquire(): + logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}") + return { + "status": "SKIPPED", + "error": "acquire lock timeout", + "end_user_id": end_user_id, + "config_id": str(config_id), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } + try: - # 尝试获取现有事件循环,如果不存在则创建新的 loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1158,7 +1175,6 @@ def write_message_task( logger.info(f"[CELERY WRITE] Task completed successfully " f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") - # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: _r = get_sync_redis_client() if _r is not None: @@ -1199,9 +1215,12 @@ def write_message_task( "elapsed_time": elapsed_time, "task_id": self.request.id } - - -# unused task + finally: + if lock is not None: + try: + lock.release() + except Exception as e: + logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: # """Call read_service and write latest status to Redis. @@ -2879,3 +2898,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace "elapsed_time": time.time() - start_time, "task_id": self.request.id, } + + +# unused task \ No newline at end of file diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index 99f62d84..a86ba46e 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,6 +1,7 @@ import redis import uuid import time +import threading UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then @@ -10,45 +11,136 @@ else end """ +RENEW_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) +else + return 0 +end +""" -class RedisLock: +CLEANUP_DEAD_HEAD_SCRIPT = """ +local queue_key = KEYS[1] +local lock_key = KEYS[2] + +local first = redis.call("lindex", queue_key, 0) +if not first then + return 0 +end + +if redis.call("exists", lock_key) == 1 then + return 0 +end + +redis.call("lpop", queue_key) +return 1 +""" + +SAFE_RELEASE_QUEUE_SCRIPT = """ +local queue_key = KEYS[1] +local value = ARGV[1] + +local first = redis.call("lindex", queue_key, 0) +if first == value then + redis.call("lpop", queue_key) + return 1 +end +return 0 +""" + + +def _ensure_str(val): + """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" + if val is None: + return None + if isinstance(val, bytes): + return val.decode("utf-8") + return str(val) + + +class RedisFairLock: def __init__( self, key: str, redis_client: redis.StrictRedis, - expire: int = 60, - retry_interval: float = 0.1, - timeout: float = 30 - + expire: int = 30, + retry_interval: float = 0.05, + timeout: float = 600, + auto_renewal: bool = True ): self.key = key - self.expire = expire + self.queue_key = f"{key}:queue" self.value = str(uuid.uuid4()) - self._locked = False + self.expire = expire self.retry_interval = retry_interval self.timeout = timeout - self.redis_client = redis_client + self.redis = redis_client + self._locked = False + self.auto_renewal = auto_renewal + self._renew_thread = None + self._stop_renew = threading.Event() - def acquire(self) -> bool: + def acquire(self): start = time.time() + + self.redis.rpush(self.queue_key, self.value) + while True: - ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) - if ok: - self._locked = True - return True - if time.time() - start >= self.timeout: + first = _ensure_str(self.redis.lindex(self.queue_key, 0)) + + if first == self.value: + ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) + if ok: + self._locked = True + + if self.auto_renewal: + self._start_renewal() + return True + + if first: + self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) + + if time.time() - start > self.timeout: + self.redis.lrem(self.queue_key, 0, self.value) return False + time.sleep(self.retry_interval) + def _renewal_loop(self): + while not self._stop_renew.is_set(): + time.sleep(self.expire / 3) + if self._stop_renew.is_set(): + break + + self.redis.eval( + RENEW_SCRIPT, + 1, + self.key, + self.value, + str(self.expire) + ) + + def _start_renewal(self): + self._stop_renew = threading.Event() + self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True) + self._renew_thread.start() + + def _stop_renewal(self): + self._stop_renew.set() + if self._renew_thread: + self._renew_thread.join(timeout=1) + def release(self): if not self._locked: return - self.redis_client.eval( - UNLOCK_SCRIPT, - 1, - self.key, - self.value - ) + + if self.auto_renewal: + self._stop_renewal() + + self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) + + self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) + self._locked = False def __enter__(self): @@ -59,3 +151,4 @@ class RedisLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() + From c7b51e7ad8a50c15504835d1a5fe5633c35ab3f8 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 16:13:45 +0800 Subject: [PATCH 092/120] fix(web): ui --- web/src/views/MemoryConversation/index.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/src/views/MemoryConversation/index.tsx b/web/src/views/MemoryConversation/index.tsx index c6428669..c33bd0e5 100644 --- a/web/src/views/MemoryConversation/index.tsx +++ b/web/src/views/MemoryConversation/index.tsx @@ -174,8 +174,8 @@ const MemoryConversation: FC = () => { /> - -
+ + { - + Date: Mon, 30 Mar 2026 16:31:23 +0800 Subject: [PATCH 093/120] fix(web): BodyWrapper add init height class --- web/src/components/Empty/BodyWrapper.tsx | 9 +++++---- web/src/views/UserMemory/index.tsx | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/web/src/components/Empty/BodyWrapper.tsx b/web/src/components/Empty/BodyWrapper.tsx index 067b743c..5d23b55c 100644 --- a/web/src/components/Empty/BodyWrapper.tsx +++ b/web/src/components/Empty/BodyWrapper.tsx @@ -24,16 +24,17 @@ interface BodyWrapperProps { /** Whether to show loading state */ loading?: boolean /** Whether the content is empty */ - empty: boolean + empty: boolean; + className?: string; } -const BodyWrapper: FC = ({ children, loading = false, empty }) => { +const BodyWrapper: FC = ({ children, loading = false, empty, className = 'rb:max-h-[calc(100%-48px)]!' }) => { // Show loading spinner while data is being fetched if (loading) { - return + return } // Show empty state when no data is available if (!loading && empty) { - return + return } // Render actual content when data is loaded and available return children diff --git a/web/src/views/UserMemory/index.tsx b/web/src/views/UserMemory/index.tsx index 96da9dec..1372929b 100644 --- a/web/src/views/UserMemory/index.tsx +++ b/web/src/views/UserMemory/index.tsx @@ -78,7 +78,7 @@ export default function UserMemory() { }, [search, data]) return ( -
+ <>
@@ -137,6 +137,6 @@ export default function UserMemory() { })} - + ); } \ No newline at end of file From f7e89af9d2b1e7e8b3517bfb38b8ea78324202ba Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Mon, 30 Mar 2026 16:44:43 +0800 Subject: [PATCH 094/120] fix(app): memory config initialization for end users - Add memory_config_id extraction and assignment when creating new end users in public share chat - Introduce get_or_create_end_user_with_config method to handle memory config setup in single transaction - Add batch_update_memory_config_id_by_app method for bulk updating end user memory configs - Rename _update_endusers_memory_config_by_workspace to _update_endusers_memory_config_by_app for correct scope - Update app publish flow to use app_id instead of workspace_id for memory config updates - Remove unused actual_end_user_id variable in langchain_agent - Ensures end users are properly associated with memory configs on creation and during app updates --- .../controllers/public_share_controller.py | 10 ++ api/app/core/agent/langchain_agent.py | 1 - api/app/repositories/end_user_repository.py | 121 ++++++++++++++++++ api/app/services/app_service.py | 16 +-- 4 files changed, 139 insertions(+), 9 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 2b224e28..fc2916ed 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -354,6 +354,16 @@ async def chat( other_id=other_id, original_user_id=user_id ) + + # Only extract and set memory_config_id when the end user doesn't have one yet + if not new_end_user.memory_config_id: + from app.services.memory_config_service import MemoryConfigService + memory_config_service = MemoryConfigService(db) + memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {}) + if memory_config_id: + new_end_user.memory_config_id = memory_config_id + db.commit() + db.refresh(new_end_user) end_user_id = str(new_end_user.id) # appid = share.app_id diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 464a668a..7314ab5f 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -329,7 +329,6 @@ class LangChainAgent: db.close() except Exception as e: logger.warning(f"Failed to get db session: {e}") - actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 3c1dd16f..aad80707 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -132,6 +132,82 @@ class EndUserRepository: db_logger.error(f"获取或创建终端用户时出错: {str(e)}") raise + def get_or_create_end_user_with_config( + self, + app_id: Optional[uuid.UUID], + workspace_id: uuid.UUID, + other_id: str, + memory_config_id: Optional[uuid.UUID] = None, + other_name: Optional[str] = None + ) -> EndUser: + """获取或创建终端用户,并在单次事务中关联记忆配置。 + + 与 get_or_create_end_user 类似,但额外支持在创建/获取时 + 一并设置 memory_config_id,避免多次提交。 + + Args: + app_id: 应用ID(可为 None) + workspace_id: 工作空间ID + other_id: 第三方ID + memory_config_id: 记忆配置ID(可选,仅在用户尚无配置时设置) + other_name: 用户名称(用于创建 EndUserInfo) + + Returns: + EndUser: 终端用户对象(已关联记忆配置) + """ + try: + end_user = ( + self.db.query(EndUser) + .filter( + EndUser.workspace_id == workspace_id, + EndUser.other_id == other_id + ) + .order_by(EndUser.created_at.asc()) + .first() + ) + + if end_user: + db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}") + if app_id is not None: + end_user.app_id = app_id + if memory_config_id and not end_user.memory_config_id: + end_user.memory_config_id = memory_config_id + self.db.commit() + self.db.refresh(end_user) + return end_user + + # 创建新用户 + end_user = EndUser( + app_id=app_id, + workspace_id=workspace_id, + other_id=other_id, + memory_config_id=memory_config_id, + ) + self.db.add(end_user) + self.db.flush() + + end_user_info = EndUserInfo( + end_user_id=end_user.id, + other_name=other_name or "", + aliases=[], + meta_data={} + ) + self.db.add(end_user_info) + + self.db.commit() + self.db.refresh(end_user) + + db_logger.info( + f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, " + f"memory_config_id={memory_config_id}" + ) + return end_user + + except Exception as e: + self.db.rollback() + db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}") + raise + def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: """根据ID获取终端用户(用于缓存操作) @@ -515,6 +591,51 @@ class EndUserRepository: ) raise + def batch_update_memory_config_id_by_app( + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID + ) -> int: + """批量更新应用下所有终端用户的 memory_config_id + + Args: + app_id: 应用ID + memory_config_id: 新的记忆配置ID + + Returns: + int: 更新的终端用户数量 + + Raises: + Exception: 数据库操作失败时抛出 + """ + try: + from sqlalchemy import update + + stmt = ( + update(EndUser) + .where(EndUser.app_id == app_id) + .values(memory_config_id=memory_config_id) + ) + + result = self.db.execute(stmt) + self.db.commit() + + updated_count = result.rowcount + + db_logger.info( + f"批量更新终端用户记忆配置: app_id={app_id}, " + f"memory_config_id={memory_config_id}, updated_count={updated_count}" + ) + + return updated_count + except Exception as e: + self.db.rollback() + db_logger.error( + f"批量更新终端用户记忆配置时出错: app_id={app_id}, " + f"memory_config_id={memory_config_id}, error={str(e)}" + ) + raise + def count_by_memory_config_id( self, memory_config_id: uuid.UUID diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index e1164206..377f9479 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1682,15 +1682,15 @@ class AppService: return config.config_id - def _update_endusers_memory_config_by_workspace( + def _update_endusers_memory_config_by_app( self, - workspace_id: uuid.UUID, + app_id: uuid.UUID, memory_config_id: uuid.UUID ) -> int: """批量更新应用下所有终端用户的 memory_config_id Args: - workspace_id: 工作空间ID + app_id: 应用ID memory_config_id: 新的记忆配置ID Returns: @@ -1699,8 +1699,8 @@ class AppService: from app.repositories.end_user_repository import EndUserRepository repo = EndUserRepository(self.db) - updated_count = repo.batch_update_memory_config_id_by_workspace( - workspace_id=workspace_id, + updated_count = repo.batch_update_memory_config_id_by_app( + app_id=app_id, memory_config_id=memory_config_id ) @@ -1879,8 +1879,8 @@ class AppService: if memory_config_id: app = self.db.query(App).filter(App.id == app_id).first() if app: - updated_count = self._update_endusers_memory_config_by_workspace( - app.workspace_id, memory_config_id + updated_count = self._update_endusers_memory_config_by_app( + app_id, memory_config_id ) logger.info( f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, " @@ -2016,7 +2016,7 @@ class AppService: if memory_config_id: - updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id) + updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id) logger.info( f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, " f"memory_config_id={memory_config_id}, updated_count={updated_count}" From 533000030fbd4b409bc9557b3ebbf83cb4535a48 Mon Sep 17 00:00:00 2001 From: Ke Sun <33739460+keeees@users.noreply.github.com> Date: Mon, 30 Mar 2026 17:16:14 +0800 Subject: [PATCH 095/120] Revert "fix(memory,task): add Redis fair lock for ordered memory writes" --- .../core/memory/llm_tools/openai_client.py | 2 +- api/app/tasks.py | 40 ++---- api/app/utils/redis_lock.py | 133 +++--------------- 3 files changed, 30 insertions(+), 145 deletions(-) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index c70fef5f..43c2b445 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -65,7 +65,7 @@ class OpenAIClient(LLMClient): type=type_ ) - logger.debug(f"OpenAI 客户端初始化完成: type={type_}") + logger.info(f"OpenAI 客户端初始化完成: type={type_}") async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index 0e909fcc..d5f09a29 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import os import re import shutil @@ -37,10 +38,12 @@ from app.db import get_db, get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema +from app.schemas.model_schema import ModelInfo from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_forget_service import MemoryForgetService +from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id -from app.utils.redis_lock import RedisFairLock +from app.utils.redis_lock import RedisLock logger = get_logger(__name__) @@ -1145,28 +1148,8 @@ def write_message_task( logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result - redis_client = get_sync_redis_client() - lock = None - if redis_client is not None: - lock = RedisFairLock( - key=f"memory_write:{end_user_id}", - redis_client=redis_client, - expire=120, - timeout=300, - auto_renewal=True, - ) - if not lock.acquire(): - logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}") - return { - "status": "SKIPPED", - "error": "acquire lock timeout", - "end_user_id": end_user_id, - "config_id": str(config_id), - "elapsed_time": time.time() - start_time, - "task_id": self.request.id, - } - try: + # 尝试获取现有事件循环,如果不存在则创建新的 loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1175,6 +1158,7 @@ def write_message_task( logger.info(f"[CELERY WRITE] Task completed successfully " f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: _r = get_sync_redis_client() if _r is not None: @@ -1215,12 +1199,9 @@ def write_message_task( "elapsed_time": elapsed_time, "task_id": self.request.id } - finally: - if lock is not None: - try: - lock.release() - except Exception as e: - logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") + + +# unused task # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: # """Call read_service and write latest status to Redis. @@ -2898,6 +2879,3 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace "elapsed_time": time.time() - start_time, "task_id": self.request.id, } - - -# unused task \ No newline at end of file diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index a86ba46e..99f62d84 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,7 +1,6 @@ import redis import uuid import time -import threading UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then @@ -11,136 +10,45 @@ else end """ -RENEW_SCRIPT = """ -if redis.call("get", KEYS[1]) == ARGV[1] then - return redis.call("expire", KEYS[1], ARGV[2]) -else - return 0 -end -""" -CLEANUP_DEAD_HEAD_SCRIPT = """ -local queue_key = KEYS[1] -local lock_key = KEYS[2] - -local first = redis.call("lindex", queue_key, 0) -if not first then - return 0 -end - -if redis.call("exists", lock_key) == 1 then - return 0 -end - -redis.call("lpop", queue_key) -return 1 -""" - -SAFE_RELEASE_QUEUE_SCRIPT = """ -local queue_key = KEYS[1] -local value = ARGV[1] - -local first = redis.call("lindex", queue_key, 0) -if first == value then - redis.call("lpop", queue_key) - return 1 -end -return 0 -""" - - -def _ensure_str(val): - """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" - if val is None: - return None - if isinstance(val, bytes): - return val.decode("utf-8") - return str(val) - - -class RedisFairLock: +class RedisLock: def __init__( self, key: str, redis_client: redis.StrictRedis, - expire: int = 30, - retry_interval: float = 0.05, - timeout: float = 600, - auto_renewal: bool = True + expire: int = 60, + retry_interval: float = 0.1, + timeout: float = 30 + ): self.key = key - self.queue_key = f"{key}:queue" - self.value = str(uuid.uuid4()) self.expire = expire + self.value = str(uuid.uuid4()) + self._locked = False self.retry_interval = retry_interval self.timeout = timeout - self.redis = redis_client - self._locked = False - self.auto_renewal = auto_renewal - self._renew_thread = None - self._stop_renew = threading.Event() + self.redis_client = redis_client - def acquire(self): + def acquire(self) -> bool: start = time.time() - - self.redis.rpush(self.queue_key, self.value) - while True: - first = _ensure_str(self.redis.lindex(self.queue_key, 0)) - - if first == self.value: - ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) - if ok: - self._locked = True - - if self.auto_renewal: - self._start_renewal() - return True - - if first: - self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) - - if time.time() - start > self.timeout: - self.redis.lrem(self.queue_key, 0, self.value) + ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) + if ok: + self._locked = True + return True + if time.time() - start >= self.timeout: return False - time.sleep(self.retry_interval) - def _renewal_loop(self): - while not self._stop_renew.is_set(): - time.sleep(self.expire / 3) - if self._stop_renew.is_set(): - break - - self.redis.eval( - RENEW_SCRIPT, - 1, - self.key, - self.value, - str(self.expire) - ) - - def _start_renewal(self): - self._stop_renew = threading.Event() - self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True) - self._renew_thread.start() - - def _stop_renewal(self): - self._stop_renew.set() - if self._renew_thread: - self._renew_thread.join(timeout=1) - def release(self): if not self._locked: return - - if self.auto_renewal: - self._stop_renewal() - - self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) - - self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) - + self.redis_client.eval( + UNLOCK_SCRIPT, + 1, + self.key, + self.value + ) self._locked = False def __enter__(self): @@ -151,4 +59,3 @@ class RedisFairLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() - From 8dd24533bfb0a7738528fcb5ea99b38ebf467947 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 30 Mar 2026 16:00:49 +0800 Subject: [PATCH 096/120] fix(memory,task): add Redis fair lock for ordered memory writes --- .../core/memory/llm_tools/openai_client.py | 2 +- api/app/tasks.py | 40 ++++-- api/app/utils/redis_lock.py | 133 +++++++++++++++--- 3 files changed, 145 insertions(+), 30 deletions(-) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 43c2b445..c70fef5f 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -65,7 +65,7 @@ class OpenAIClient(LLMClient): type=type_ ) - logger.info(f"OpenAI 客户端初始化完成: type={type_}") + logger.debug(f"OpenAI 客户端初始化完成: type={type_}") async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index d5f09a29..0e909fcc 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,5 +1,4 @@ import asyncio -import hashlib import os import re import shutil @@ -38,12 +37,10 @@ from app.db import get_db, get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema -from app.schemas.model_schema import ModelInfo from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_forget_service import MemoryForgetService -from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id -from app.utils.redis_lock import RedisLock +from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) @@ -1148,8 +1145,28 @@ def write_message_task( logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result + redis_client = get_sync_redis_client() + lock = None + if redis_client is not None: + lock = RedisFairLock( + key=f"memory_write:{end_user_id}", + redis_client=redis_client, + expire=120, + timeout=300, + auto_renewal=True, + ) + if not lock.acquire(): + logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}") + return { + "status": "SKIPPED", + "error": "acquire lock timeout", + "end_user_id": end_user_id, + "config_id": str(config_id), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } + try: - # 尝试获取现有事件循环,如果不存在则创建新的 loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1158,7 +1175,6 @@ def write_message_task( logger.info(f"[CELERY WRITE] Task completed successfully " f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") - # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: _r = get_sync_redis_client() if _r is not None: @@ -1199,9 +1215,12 @@ def write_message_task( "elapsed_time": elapsed_time, "task_id": self.request.id } - - -# unused task + finally: + if lock is not None: + try: + lock.release() + except Exception as e: + logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: # """Call read_service and write latest status to Redis. @@ -2879,3 +2898,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace "elapsed_time": time.time() - start_time, "task_id": self.request.id, } + + +# unused task \ No newline at end of file diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index 99f62d84..a86ba46e 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,6 +1,7 @@ import redis import uuid import time +import threading UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then @@ -10,45 +11,136 @@ else end """ +RENEW_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) +else + return 0 +end +""" -class RedisLock: +CLEANUP_DEAD_HEAD_SCRIPT = """ +local queue_key = KEYS[1] +local lock_key = KEYS[2] + +local first = redis.call("lindex", queue_key, 0) +if not first then + return 0 +end + +if redis.call("exists", lock_key) == 1 then + return 0 +end + +redis.call("lpop", queue_key) +return 1 +""" + +SAFE_RELEASE_QUEUE_SCRIPT = """ +local queue_key = KEYS[1] +local value = ARGV[1] + +local first = redis.call("lindex", queue_key, 0) +if first == value then + redis.call("lpop", queue_key) + return 1 +end +return 0 +""" + + +def _ensure_str(val): + """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" + if val is None: + return None + if isinstance(val, bytes): + return val.decode("utf-8") + return str(val) + + +class RedisFairLock: def __init__( self, key: str, redis_client: redis.StrictRedis, - expire: int = 60, - retry_interval: float = 0.1, - timeout: float = 30 - + expire: int = 30, + retry_interval: float = 0.05, + timeout: float = 600, + auto_renewal: bool = True ): self.key = key - self.expire = expire + self.queue_key = f"{key}:queue" self.value = str(uuid.uuid4()) - self._locked = False + self.expire = expire self.retry_interval = retry_interval self.timeout = timeout - self.redis_client = redis_client + self.redis = redis_client + self._locked = False + self.auto_renewal = auto_renewal + self._renew_thread = None + self._stop_renew = threading.Event() - def acquire(self) -> bool: + def acquire(self): start = time.time() + + self.redis.rpush(self.queue_key, self.value) + while True: - ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) - if ok: - self._locked = True - return True - if time.time() - start >= self.timeout: + first = _ensure_str(self.redis.lindex(self.queue_key, 0)) + + if first == self.value: + ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) + if ok: + self._locked = True + + if self.auto_renewal: + self._start_renewal() + return True + + if first: + self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) + + if time.time() - start > self.timeout: + self.redis.lrem(self.queue_key, 0, self.value) return False + time.sleep(self.retry_interval) + def _renewal_loop(self): + while not self._stop_renew.is_set(): + time.sleep(self.expire / 3) + if self._stop_renew.is_set(): + break + + self.redis.eval( + RENEW_SCRIPT, + 1, + self.key, + self.value, + str(self.expire) + ) + + def _start_renewal(self): + self._stop_renew = threading.Event() + self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True) + self._renew_thread.start() + + def _stop_renewal(self): + self._stop_renew.set() + if self._renew_thread: + self._renew_thread.join(timeout=1) + def release(self): if not self._locked: return - self.redis_client.eval( - UNLOCK_SCRIPT, - 1, - self.key, - self.value - ) + + if self.auto_renewal: + self._stop_renewal() + + self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) + + self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) + self._locked = False def __enter__(self): @@ -59,3 +151,4 @@ class RedisLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() + From 9d91453200382dc69f36cbb7173ee3ad277ef518 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 30 Mar 2026 17:28:13 +0800 Subject: [PATCH 097/120] fix(mcp): Addressing the issue of asynchronous connections for the MCP --- api/app/core/tools/mcp/client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 6df6df51..b437d021 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -99,7 +99,7 @@ class SimpleMCPClient: # 建立 SSE 连接 response = await self._session.get(self.server_url) - if response.status != 200: + if response.status not in (200, 202): error_text = await response.text() raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") @@ -190,7 +190,9 @@ class SimpleMCPClient: try: async with self._session.post(self._endpoint_url, json=request) as response: - if response.status != 200: + # MCP SSE 协议:POST 请求返回 200 或 202 均为正常 + # 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回 + if response.status not in (200, 202): error_text = await response.text() raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") @@ -205,7 +207,7 @@ class SimpleMCPClient: raise MCPConnectionError("endpoint URL 未初始化") async with self._session.post(self._endpoint_url, json=notification) as response: - if response.status != 200: + if response.status not in (200, 202): logger.warning(f"通知发送失败: {response.status}") async def _initialize_modelscope_session(self): From e15af5a2ba6b7d4fd1f9be466742d574cbe001eb Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 16:07:15 +0800 Subject: [PATCH 098/120] [fix] Create a complete index --- api/app/core/memory/agent/utils/write_tools.py | 4 ++-- api/app/repositories/neo4j/create_indexes.py | 12 ------------ 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 55bcb8ba..abbcc54d 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -151,9 +151,9 @@ async def write( # Step 3: Save all data to Neo4j database step_start = time.time() - from app.repositories.neo4j.create_indexes import create_fulltext_indexes + from app.repositories.neo4j.create_indexes import create_all_indexes try: - await create_fulltext_indexes() + await create_all_indexes() except Exception as e: logger.error(f"Error creating indexes: {e}", exc_info=True) diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index d9e94117..a10ee9a1 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -144,18 +144,6 @@ async def create_vector_indexes(): """) print("✓ Created: dialogue_embedding_index") - # Community summary embedding index - await connector.execute_query(""" - CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS - FOR (c:Community) - ON c.summary_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: community_summary_embedding_index") - print("\nVector indexes created successfully!") print("\nExpected performance improvement:") print(" Before: ~1.4s for embedding search") From d42db0ca33d94238274c97dc4f8eeeb38a406560 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 17:36:37 +0800 Subject: [PATCH 099/120] [fix] Delete the index creation for the "config_id" field --- api/app/repositories/neo4j/create_indexes.py | 48 -------------------- 1 file changed, 48 deletions(-) diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index a10ee9a1..334897e2 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -155,54 +155,6 @@ async def create_vector_indexes(): await connector.close() -async def create_config_id_indexes(): - """Create indexes on config_id fields for improved query performance. - - These indexes enable fast filtering of nodes by configuration ID, - which is essential for configuration isolation and multi-tenant scenarios. - """ - connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Config ID Indexes") - print("=" * 70) - - # Dialogue.config_id index - await connector.execute_query(""" - CREATE INDEX dialogue_config_id_index IF NOT EXISTS - FOR (d:Dialogue) ON (d.config_id) - """) - print("✓ Created: dialogue_config_id_index") - - # Statement.config_id index - await connector.execute_query(""" - CREATE INDEX statement_config_id_index IF NOT EXISTS - FOR (s:Statement) ON (s.config_id) - """) - print("✓ Created: statement_config_id_index") - - # ExtractedEntity.config_id index - await connector.execute_query(""" - CREATE INDEX entity_config_id_index IF NOT EXISTS - FOR (e:ExtractedEntity) ON (e.config_id) - """) - print("✓ Created: entity_config_id_index") - - # MemorySummary.config_id index - await connector.execute_query(""" - CREATE INDEX summary_config_id_index IF NOT EXISTS - FOR (m:MemorySummary) ON (m.config_id) - """) - print("✓ Created: summary_config_id_index") - - print("\nConfig ID indexes created successfully!") - print("These indexes enable fast filtering by configuration ID.") - - except Exception as e: - print(f"✗ Error creating config_id indexes: {e}") - finally: - await connector.close() - async def create_unique_constraints(): """Create uniqueness constraints for core node identifiers. From 052c7c19b3699c83902e122ac8eadd519bef323f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 17:42:47 +0800 Subject: [PATCH 100/120] [fix] Avoid unnecessary index creation costs --- api/app/core/memory/agent/utils/write_tools.py | 5 ----- api/app/main.py | 9 +++++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index abbcc54d..3af9326e 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -151,11 +151,6 @@ async def write( # Step 3: Save all data to Neo4j database step_start = time.time() - from app.repositories.neo4j.create_indexes import create_all_indexes - try: - await create_all_indexes() - except Exception as e: - logger.error(f"Error creating indexes: {e}", exc_info=True) # 添加死锁重试机制 max_retries = 3 diff --git a/api/app/main.py b/api/app/main.py index f4c23ca8..2fdf40b6 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -62,6 +62,15 @@ async def lifespan(app: FastAPI): logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("应用程序启动完成") + + # 初始化 Neo4j 索引和约束(仅启动时执行一次) + try: + from app.repositories.neo4j.create_indexes import create_all_indexes + await create_all_indexes() + logger.info("Neo4j 索引和约束初始化完成") + except Exception as e: + logger.warning(f"Neo4j 索引初始化失败(服务仍可启动,但查询性能可能受影响): {e}") + yield # 应用关闭事件 logger.info("应用程序正在关闭") From 83774d744392ce024ba2a0074e56bf246ab67bac Mon Sep 17 00:00:00 2001 From: wxy Date: Mon, 30 Mar 2026 18:09:35 +0800 Subject: [PATCH 101/120] feat: optimize app log controller code structure --- api/app/controllers/app_log_controller.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py index ac0b2ac4..adf90ca4 100644 --- a/api/app/controllers/app_log_controller.py +++ b/api/app/controllers/app_log_controller.py @@ -25,14 +25,12 @@ def list_app_logs( app_id: uuid.UUID, page: int = Query(1, ge=1), pagesize: int = Query(20, ge=1, le=100), - user_id: Optional[str] = None, is_draft: Optional[bool] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): """查看应用下所有会话记录(分页) - - 支持按 user_id 筛选 - 支持按 is_draft 筛选(草稿会话 / 发布会话) - 按最新更新时间倒序排列 - 所有人(包括共享者和被共享者)都只能查看自己的会话记录 @@ -48,12 +46,6 @@ def list_app_logs( Conversation.workspace_id == workspace_id, Conversation.is_active.is_(True), ) - - # 所有人只能查看自己的会话记录 - stmt = stmt.where(Conversation.user_id == str(current_user.id)) - - if user_id: - stmt = stmt.where(Conversation.user_id == user_id) if is_draft is not None: stmt = stmt.where(Conversation.is_draft == is_draft) @@ -105,7 +97,6 @@ def get_app_log_detail( Conversation.app_id == app_id, Conversation.workspace_id == workspace_id, Conversation.is_active.is_(True), - Conversation.user_id == str(current_user.id), ) ).first() From 418114ef72c34cc6491749ca354a48a0691b8534 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 18:14:31 +0800 Subject: [PATCH 102/120] [fix] Modify Index Creation --- api/app/main.py | 10 +- api/app/repositories/neo4j/create_indexes.py | 171 ++----------------- 2 files changed, 18 insertions(+), 163 deletions(-) diff --git a/api/app/main.py b/api/app/main.py index 2fdf40b6..9e501f11 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -1,5 +1,6 @@ import os import subprocess +from app.repositories.neo4j.create_indexes import create_all_indexes from contextlib import asynccontextmanager from fastapi import FastAPI, APIRouter @@ -60,16 +61,9 @@ async def lifespan(app: FastAPI): logger.warning(f"加载预定义模型时出错: {str(e)}") else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") - + await create_all_indexes() logger.info("应用程序启动完成") - # 初始化 Neo4j 索引和约束(仅启动时执行一次) - try: - from app.repositories.neo4j.create_indexes import create_all_indexes - await create_all_indexes() - logger.info("Neo4j 索引和约束初始化完成") - except Exception as e: - logger.warning(f"Neo4j 索引初始化失败(服务仍可启动,但查询性能可能受影响): {e}") yield # 应用关闭事件 diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index 334897e2..5132aa09 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -1,62 +1,47 @@ +import asyncio from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - async def create_fulltext_indexes(): """Create full-text indexes for keyword search with BM25 scoring.""" connector = Neo4jConnector() try: - print("\n" + "=" * 70) - print("Creating Full-Text Indexes (for keyword search)") - print("=" * 70) + # 创建 Statements 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: statementsFulltext") + """) # # 创建 Dialogues 索引 # await connector.execute_query(""" # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] # OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } # """) - # 创建 Entities 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: entitiesFulltext") + """) # 创建 Chunks 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: chunksFulltext") + """) # 创建 MemorySummary 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: summariesFulltext") - + """) # 创建 Community 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } """) - print("✓ Created: communitiesFulltext") - print("\nFull-text indexes created successfully with BM25 support.") - except Exception as e: - print(f"✗ Error creating full-text indexes: {e}") finally: await connector.close() - - async def create_vector_indexes(): """Create vector indexes for fast embedding similarity search. @@ -65,12 +50,7 @@ async def create_vector_indexes(): """ connector = Neo4jConnector() try: - print("\n" + "=" * 70) - print("Creating Vector Indexes (for embedding search)") - print("=" * 70) - print("Note: Adjust vector.dimensions if using different embedding model") - print(" Current setting: 1024 dimensions (for bge-m3)") - print() + # Statement embedding index await connector.execute_query(""" @@ -82,7 +62,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: statement_embedding_index") + # Chunk embedding index await connector.execute_query(""" @@ -94,7 +74,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: chunk_embedding_index") + # Entity name embedding index await connector.execute_query(""" @@ -106,7 +86,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: entity_embedding_index") + # Memory summary embedding index await connector.execute_query(""" @@ -118,8 +98,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: summary_embedding_index") - + # Community summary embedding index await connector.execute_query(""" CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS @@ -129,8 +108,7 @@ async def create_vector_indexes(): `vector.dimensions`: 1024, `vector.similarity_function`: 'cosine' }} - """) - print("✓ Created: community_summary_embedding_index") + """) # Dialogue embedding index (optional) await connector.execute_query(""" @@ -142,31 +120,15 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - print("✓ Created: dialogue_embedding_index") - - print("\nVector indexes created successfully!") - print("\nExpected performance improvement:") - print(" Before: ~1.4s for embedding search") - print(" After: ~0.05-0.2s for embedding search (10-30x faster!)") - except Exception as e: - print(f"✗ Error creating vector indexes: {e}") finally: await connector.close() - - - async def create_unique_constraints(): """Create uniqueness constraints for core node identifiers. - Ensures concurrent MERGE operations remain safe and prevents duplicates. """ connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Unique Constraints") - print("=" * 70) - + try: # Dialogue.id unique await connector.execute_query( """ @@ -174,8 +136,7 @@ async def create_unique_constraints(): FOR (d:Dialogue) REQUIRE d.id IS UNIQUE """ ) - print("✓ Created: dialog_id_unique") - + # Statement.id unique await connector.execute_query( """ @@ -183,8 +144,7 @@ async def create_unique_constraints(): FOR (s:Statement) REQUIRE s.id IS UNIQUE """ ) - print("✓ Created: statement_id_unique") - + # Chunk.id unique await connector.execute_query( """ @@ -192,112 +152,13 @@ async def create_unique_constraints(): FOR (c:Chunk) REQUIRE c.id IS UNIQUE """ ) - print("✓ Created: chunk_id_unique") - - print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.") - except Exception as e: - print(f"✗ Error creating unique constraints: {e}") + finally: await connector.close() - - async def create_all_indexes(): """Create all indexes and constraints in one go.""" - print("\n" + "=" * 70) - print("Neo4j Index & Constraint Setup") - print("=" * 70) - print("This will create:") - print(" 1. Full-text indexes (for keyword/BM25 search)") - print(" 2. Vector indexes (for embedding similarity search)") - print(" 3. Config ID indexes (for configuration isolation)") - print(" 4. Unique constraints (for data integrity)") - print("=" * 70) - await create_fulltext_indexes() await create_vector_indexes() - await create_config_id_indexes() await create_unique_constraints() - - print("\n" + "=" * 70) print("✓ All indexes and constraints created successfully!") - print("=" * 70) - print("\nTo verify, run in Neo4j Browser:") - print(" SHOW INDEXES") - print(" SHOW CONSTRAINTS") - print() - - -async def check_indexes(): - """Check what indexes currently exist.""" - connector = Neo4jConnector() - - try: - print("\n" + "=" * 70) - print("Checking Existing Indexes") - print("=" * 70) - query = "SHOW INDEXES" - result = await connector.execute_query(query) - - fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT'] - vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR'] - range_indexes = [idx for idx in result if idx.get('type') == 'RANGE'] - - print(f"\nFull-text indexes: {len(fulltext_indexes)}") - for idx in fulltext_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nVector indexes: {len(vector_indexes)}") - for idx in vector_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nRange indexes (including config_id): {len(range_indexes)}") - for idx in range_indexes: - print(f" ✓ {idx.get('name')}") - - if not vector_indexes: - print("\n⚠️ WARNING: No vector indexes found!") - print(" Embedding search will be VERY SLOW (~1.4s)") - print(" Run: python create_indexes.py") - - # Check for config_id indexes - config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')] - if len(config_id_indexes) < 4: - print("\n⚠️ WARNING: Not all config_id indexes found!") - print(f" Expected 4, found {len(config_id_indexes)}") - print(" Run: python create_indexes.py config_id") - - print("=" * 70) - - finally: - await connector.close() - - -if __name__ == "__main__": - import asyncio - import sys - - if len(sys.argv) > 1: - command = sys.argv[1] - if command == "check": - asyncio.run(check_indexes()) - elif command == "fulltext": - asyncio.run(create_fulltext_indexes()) - elif command == "vector": - asyncio.run(create_vector_indexes()) - elif command == "config_id": - asyncio.run(create_config_id_indexes()) - elif command == "constraints": - asyncio.run(create_unique_constraints()) - else: - print(f"Unknown command: {command}") - print("\nUsage:") - print(" python create_indexes.py # Create all indexes") - print(" python create_indexes.py check # Check existing indexes") - print(" python create_indexes.py fulltext # Create only full-text indexes") - print(" python create_indexes.py vector # Create only vector indexes") - print(" python create_indexes.py config_id # Create only config_id indexes") - print(" python create_indexes.py constraints # Create only constraints") - else: - asyncio.run(create_all_indexes()) - From c90b58bbcd891e95ab2b5ddeb38819c9dfe132ab Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 17:05:59 +0800 Subject: [PATCH 103/120] [fix] The "write_tools" module actively shuts down the client, and it closes before the task event loop is completed. --- .../core/memory/agent/utils/write_tools.py | 16 ++++++++ api/app/tasks.py | 41 ++++++++++++++++--- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 3af9326e..1f437973 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -274,5 +274,21 @@ async def write( except Exception as cache_err: logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + # Close LLM/Embedder underlying httpx clients to prevent + # 'RuntimeError: Event loop is closed' during garbage collection + for client_obj in (llm_client, embedder_client): + try: + underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None) + if underlying is None: + continue + # Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model + inner = getattr(underlying, '_model', underlying) + # LangChain OpenAI models expose async_client (httpx.AsyncClient) + http_client = getattr(inner, 'async_client', None) + if http_client is not None and hasattr(http_client, 'aclose'): + await http_client.aclose() + except Exception: + pass + logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: {total_time:.2f} seconds") diff --git a/api/app/tasks.py b/api/app/tasks.py index 0e909fcc..b7826332 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -101,7 +101,11 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: def set_asyncio_event_loop(): - """Set the asyncio event loop for the current thread.""" + """Set the asyncio event loop for the current thread. + + Always creates a fresh event loop to avoid 'Event loop is closed' errors + caused by stale httpx.AsyncClient objects from previous task runs. + """ try: loop = asyncio.get_event_loop() if loop.is_closed(): @@ -113,6 +117,30 @@ def set_asyncio_event_loop(): return loop +def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop): + """Gracefully shutdown pending async generators and tasks on the event loop. + + This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ + by giving pending aclose() coroutines a chance to run before the loop is discarded. + """ + try: + # Cancel and collect all remaining tasks + all_tasks = asyncio.all_tasks(loop) + if all_tasks: + for task in all_tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True)) + # Shutdown async generators (triggers __aclose__ on httpx clients etc.) + loop.run_until_complete(loop.shutdown_asyncgens()) + except Exception: + pass + finally: + loop.close() + # Set a new event loop so subsequent tasks get a fresh one + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + + @celery_app.task(name="tasks.process_item") def process_item(item: dict): """ @@ -1216,11 +1244,12 @@ def write_message_task( "task_id": self.request.id } finally: - if lock is not None: - try: - lock.release() - except Exception as e: - logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") + # Gracefully shutdown the event loop to prevent + # 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ + _shutdown_loop_gracefully(loop) + + +# unused task # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: # """Call read_service and write latest status to Redis. From 4974f9aa98568b89cbc21e2ac85618ab4e006919 Mon Sep 17 00:00:00 2001 From: wxy Date: Mon, 30 Mar 2026 18:27:44 +0800 Subject: [PATCH 104/120] refactor: extract app log SQL queries to Service and Repository layers --- api/app/controllers/app_log_controller.py | 81 +++-------- .../repositories/conversation_repository.py | 118 ++++++++++++++++ api/app/services/app_log_service.py | 128 ++++++++++++++++++ 3 files changed, 268 insertions(+), 59 deletions(-) create mode 100644 api/app/services/app_log_service.py diff --git a/api/app/controllers/app_log_controller.py b/api/app/controllers/app_log_controller.py index adf90ca4..92b5becd 100644 --- a/api/app/controllers/app_log_controller.py +++ b/api/app/controllers/app_log_controller.py @@ -3,17 +3,16 @@ import uuid from typing import Optional from fastapi import APIRouter, Depends, Query -from sqlalchemy import select, desc, func from sqlalchemy.orm import Session from app.core.logging_config import get_business_logger from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user, cur_workspace_access_guard -from app.models.conversation_model import Conversation, Message -from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage +from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail from app.schemas.response_schema import PageData, PageMeta from app.services.app_service import AppService +from app.services.app_log_service import AppLogService router = APIRouter(prefix="/apps", tags=["App Logs"]) logger = get_business_logger() @@ -38,35 +37,22 @@ def list_app_logs( workspace_id = current_user.current_workspace_id # 验证应用访问权限 - service = AppService(db) - service.get_app(app_id, workspace_id) + app_service = AppService(db) + app_service.get_app(app_id, workspace_id) - stmt = select(Conversation).where( - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active.is_(True), + # 使用 Service 层查询 + log_service = AppLogService(db) + conversations, total = log_service.list_conversations( + app_id=app_id, + workspace_id=workspace_id, + page=page, + pagesize=pagesize, + is_draft=is_draft ) - if is_draft is not None: - stmt = stmt.where(Conversation.is_draft == is_draft) - - total = int(db.execute( - select(func.count()).select_from(stmt.subquery()) - ).scalar_one()) - - stmt = stmt.order_by(desc(Conversation.updated_at)) - stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) - - conversations = list(db.scalars(stmt).all()) - items = [AppLogConversation.model_validate(c) for c in conversations] meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) - logger.info( - "查询应用日志会话列表", - extra={"app_id": str(app_id), "total": total, "page": page} - ) - return success(data=PageData(page=meta, items=items)) @@ -87,40 +73,17 @@ def get_app_log_detail( workspace_id = current_user.current_workspace_id # 验证应用访问权限 - service = AppService(db) - service.get_app(app_id, workspace_id) + app_service = AppService(db) + app_service.get_app(app_id, workspace_id) - # 查询会话(确保属于该应用和工作空间) - conversation = db.scalars( - select(Conversation).where( - Conversation.id == conversation_id, - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active.is_(True), - ) - ).first() - - if not conversation: - from app.core.exceptions import ResourceNotFoundException - raise ResourceNotFoundException("会话", str(conversation_id)) - - # 查询消息(按时间正序) - messages = list(db.scalars( - select(Message) - .where(Message.conversation_id == conversation_id) - .order_by(Message.created_at) - ).all()) - - detail = AppLogConversationDetail.model_validate(conversation) - detail.messages = [AppLogMessage.model_validate(m) for m in messages] - - logger.info( - "查询应用日志会话详情", - extra={ - "app_id": str(app_id), - "conversation_id": str(conversation_id), - "message_count": len(messages) - } + # 使用 Service 层查询 + log_service = AppLogService(db) + conversation = log_service.get_conversation_detail( + app_id=app_id, + conversation_id=conversation_id, + workspace_id=workspace_id ) + detail = AppLogConversationDetail.model_validate(conversation) + return success(data=detail) diff --git a/api/app/repositories/conversation_repository.py b/api/app/repositories/conversation_repository.py index 90f2d6ec..0676a255 100644 --- a/api/app/repositories/conversation_repository.py +++ b/api/app/repositories/conversation_repository.py @@ -199,6 +199,96 @@ class ConversationRepository: ) return conversations, total + def list_app_conversations( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + is_draft: Optional[bool] = None, + page: int = 1, + pagesize: int = 20 + ) -> tuple[list[Conversation], int]: + """ + 查询应用日志会话列表(带分页和过滤) + + Args: + app_id: 应用 ID + workspace_id: 工作空间 ID + is_draft: 是否草稿会话(None 表示不过滤) + page: 页码(从 1 开始) + pagesize: 每页数量 + + Returns: + Tuple[List[Conversation], int]: (会话列表,总数) + """ + stmt = select(Conversation).where( + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True) + ) + + if is_draft is not None: + stmt = stmt.where(Conversation.is_draft == is_draft) + + # Calculate total number of records + total = int(self.db.execute( + select(func.count()).select_from(stmt.subquery()) + ).scalar_one()) + + # Apply pagination + stmt = stmt.order_by(desc(Conversation.updated_at)) + stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) + + conversations = list(self.db.scalars(stmt).all()) + + logger.info( + "Listed app conversations successfully", + extra={ + "app_id": str(app_id), + "workspace_id": str(workspace_id), + "returned": len(conversations), + "total": total + } + ) + return conversations, total + + def get_conversation_for_app_log( + self, + conversation_id: uuid.UUID, + app_id: uuid.UUID, + workspace_id: uuid.UUID + ) -> Conversation: + """ + 查询应用日志的会话详情 + + Args: + conversation_id: 会话 ID + app_id: 应用 ID + workspace_id: 工作空间 ID + + Returns: + Conversation: 会话对象 + + Raises: + ResourceNotFoundException: 当会话不存在时 + """ + logger.info(f"Fetching conversation for app log: {conversation_id}") + + stmt = select(Conversation).where( + Conversation.id == conversation_id, + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.is_active.is_(True) + ) + + conversation = self.db.scalars(stmt).first() + + if not conversation: + logger.warning(f"Conversation not found: {conversation_id}") + raise ResourceNotFoundException("会话", str(conversation_id)) + + logger.info(f"Conversation fetched successfully: {conversation_id}") + return conversation + def soft_delete_conversation_by_conversation_id( self, conversation_id: uuid.UUID, @@ -290,6 +380,34 @@ class MessageRepository: self.db.add(message) return message + def get_messages_by_conversation( + self, + conversation_id: uuid.UUID + ) -> list[Message]: + """ + 查询会话的所有消息(按时间正序) + + Args: + conversation_id: 会话 ID + + Returns: + List[Message]: 消息列表 + """ + stmt = select(Message).where( + Message.conversation_id == conversation_id + ).order_by(Message.created_at) + + messages = list(self.db.scalars(stmt).all()) + + logger.info( + "Fetched messages for conversation", + extra={ + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + return messages + def get_message_by_conversation_id( self, conversation_id: uuid.UUID, diff --git a/api/app/services/app_log_service.py b/api/app/services/app_log_service.py new file mode 100644 index 00000000..856045d1 --- /dev/null +++ b/api/app/services/app_log_service.py @@ -0,0 +1,128 @@ +"""应用日志服务层""" +import uuid +from typing import Optional, Tuple +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.core.logging_config import get_business_logger +from app.models.conversation_model import Conversation, Message +from app.repositories.conversation_repository import ConversationRepository, MessageRepository + +logger = get_business_logger() + + +class AppLogService: + """应用日志服务""" + + def __init__(self, db: Session): + self.db = db + self.conversation_repository = ConversationRepository(db) + self.message_repository = MessageRepository(db) + + def list_conversations( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + page: int = 1, + pagesize: int = 20, + is_draft: Optional[bool] = None, + ) -> Tuple[list[Conversation], int]: + """ + 查询应用日志会话列表 + + Args: + app_id: 应用 ID + workspace_id: 工作空间 ID + page: 页码(从 1 开始) + pagesize: 每页数量 + is_draft: 是否草稿会话(None 表示不过滤) + + Returns: + Tuple[list[Conversation], int]: (会话列表,总数) + """ + logger.info( + "查询应用日志会话列表", + extra={ + "app_id": str(app_id), + "workspace_id": str(workspace_id), + "page": page, + "pagesize": pagesize, + "is_draft": is_draft + } + ) + + # 使用 Repository 查询 + conversations, total = self.conversation_repository.list_app_conversations( + app_id=app_id, + workspace_id=workspace_id, + is_draft=is_draft, + page=page, + pagesize=pagesize + ) + + logger.info( + "查询应用日志会话列表成功", + extra={ + "app_id": str(app_id), + "total": total, + "returned": len(conversations) + } + ) + + return conversations, total + + def get_conversation_detail( + self, + app_id: uuid.UUID, + conversation_id: uuid.UUID, + workspace_id: uuid.UUID + ) -> Conversation: + """ + 查询会话详情(包含消息) + + Args: + app_id: 应用 ID + conversation_id: 会话 ID + workspace_id: 工作空间 ID + + Returns: + Conversation: 包含消息的会话对象 + + Raises: + ResourceNotFoundException: 当会话不存在时 + """ + logger.info( + "查询应用日志会话详情", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "workspace_id": str(workspace_id) + } + ) + + # 查询会话 + conversation = self.conversation_repository.get_conversation_for_app_log( + conversation_id=conversation_id, + app_id=app_id, + workspace_id=workspace_id + ) + + # 查询消息(按时间正序) + messages = self.message_repository.get_messages_by_conversation( + conversation_id=conversation_id + ) + + # 将消息附加到会话对象 + conversation.messages = messages + + logger.info( + "查询应用日志会话详情成功", + extra={ + "app_id": str(app_id), + "conversation_id": str(conversation_id), + "message_count": len(messages) + } + ) + + return conversation From 0c677701c0f37f8ad6b3b008b10cf05c12b9fd25 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 18:29:17 +0800 Subject: [PATCH 105/120] [fix] iron release --- api/app/tasks.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/api/app/tasks.py b/api/app/tasks.py index b7826332..73b001e2 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1244,6 +1244,11 @@ def write_message_task( "task_id": self.request.id } finally: + if lock is not None: + try: + lock.release() + except Exception as e: + logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") # Gracefully shutdown the event loop to prevent # 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ _shutdown_loop_gracefully(loop) From 6e7c641fd4965045f051ae9b375c5eab4da66f63 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 18:46:25 +0800 Subject: [PATCH 106/120] [fix] Remove duplicate creations --- api/app/tasks.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/api/app/tasks.py b/api/app/tasks.py index 73b001e2..4928ca7f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -101,10 +101,11 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: def set_asyncio_event_loop(): - """Set the asyncio event loop for the current thread. + """Ensure an open asyncio event loop exists for the current thread. - Always creates a fresh event loop to avoid 'Event loop is closed' errors - caused by stale httpx.AsyncClient objects from previous task runs. + Reuses the existing event loop if one is available and still open. + Creates and installs a new event loop only when the current one is + closed or missing (e.g. after ``_shutdown_loop_gracefully``). """ try: loop = asyncio.get_event_loop() @@ -122,6 +123,9 @@ def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop): This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ by giving pending aclose() coroutines a chance to run before the loop is discarded. + + Note: This only tears down the given loop. Callers that need a fresh event + loop afterwards should use ``set_asyncio_event_loop()`` explicitly. """ try: # Cancel and collect all remaining tasks @@ -136,9 +140,6 @@ def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop): pass finally: loop.close() - # Set a new event loop so subsequent tasks get a fresh one - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) @celery_app.task(name="tasks.process_item") From 3419bb137a1933644f1ea8dc1329cada916af12e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 19:56:02 +0800 Subject: [PATCH 107/120] [fix] Fix the alias query statement --- api/app/repositories/neo4j/cypher_queries.py | 23 ++++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index c08f9d0e..26ffe350 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """ CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) WITH e, score -UNION -MATCH (e:ExtractedEntity) -WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) - AND e.aliases IS NOT NULL - AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q)) -WITH e, +WITH collect({entity: e, score: score}) AS fulltextResults + +OPTIONAL MATCH (ae:ExtractedEntity) +WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id) + AND ae.aliases IS NOT NULL + AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q)) +WITH fulltextResults, collect(ae) AS aliasEntities + +UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score: CASE - WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 - WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 + WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 ELSE 0.8 - END AS score + END +}]) AS row +WITH row.entity AS e, row.score AS score WITH DISTINCT e, MAX(score) AS score OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) From b7198f1abd7c66a3d7aef6a30426ce434be2f592 Mon Sep 17 00:00:00 2001 From: wxy Date: Mon, 30 Mar 2026 20:08:12 +0800 Subject: [PATCH 108/120] fix: allow shared users to view request logs for their own API keys --- api/app/controllers/service/app_api_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 32a911f9..d4573464 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -91,7 +91,7 @@ async def chat( app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) other_id = payload.user_id - workspace_id = app.workspace_id + workspace_id = api_key_auth.workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( app_id=app.id, From dbe387f666f28fe3625d91f219911d579ea4d5b1 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Mon, 30 Mar 2026 20:53:17 +0800 Subject: [PATCH 109/120] fix(tasks): increase redis lock timeout and expiration for write_message_task - Increase lock expiration time from 120 to 600 seconds (5 minutes) - Increase lock timeout from 300 to 3600 seconds (1 hour) - Prevents premature lock release during long-running memory write operations --- api/app/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/tasks.py b/api/app/tasks.py index 4928ca7f..72421a5f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1180,8 +1180,8 @@ def write_message_task( lock = RedisFairLock( key=f"memory_write:{end_user_id}", redis_client=redis_client, - expire=120, - timeout=300, + expire=600, + timeout=3600, auto_renewal=True, ) if not lock.acquire(): From abc27c837213df859f3a743edeffa81b5af99c20 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 30 Mar 2026 21:17:21 +0800 Subject: [PATCH 110/120] [fix] Add the function for judging the event loop switch --- api/app/aioRedis.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index f79ef0e1..357533ad 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -33,20 +33,24 @@ _thread_local = threading.local() def get_thread_safe_redis() -> redis.StrictRedis: - """Get a Redis client safe for the current execution context. - - Uses thread-local storage with PID checking to ensure: - - Each thread gets its own ConnectionPool (Celery --pool=threads) - - Pools are recreated after fork (Celery --pool=prefork) - - health_check_interval prevents stale connection errors - - Returns: - redis.StrictRedis: A Redis client with a thread/process-local pool. + """Return a Redis client whose connection pool is bound to the current + thread, process **and** event loop. + + The pool is recreated when: + - The PID changes (fork, Celery --pool=prefork) + - The thread has no pool yet (Celery --pool=threads) + - The previously-cached event loop has been closed (Celery tasks call + ``_shutdown_loop_gracefully`` which closes the loop after each run) """ current_pid = os.getpid() + cached_loop = getattr(_thread_local, "loop", None) + loop_stale = cached_loop is not None and cached_loop.is_closed() - if not hasattr(_thread_local, "pool") or getattr(_thread_local, "pid", None) != current_pid: + if not hasattr(_thread_local, "pool") \ + or getattr(_thread_local, "pid", None) != current_pid \ + or loop_stale: _thread_local.pid = current_pid + _thread_local.loop = asyncio.get_event_loop() _thread_local.pool = ConnectionPool.from_url( _REDIS_URL, db=settings.REDIS_DB, From 2d6cde157e615c7813d6baee9319593c3e3fa19f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 31 Mar 2026 09:59:39 +0800 Subject: [PATCH 111/120] [fix] No event loop is set and defensive programming is not used for non-main thread calls. --- api/app/aioRedis.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index 357533ad..dfb63dad 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -50,7 +50,12 @@ def get_thread_safe_redis() -> redis.StrictRedis: or getattr(_thread_local, "pid", None) != current_pid \ or loop_stale: _thread_local.pid = current_pid - _thread_local.loop = asyncio.get_event_loop() + # Python 3.10+: get_event_loop() raises RuntimeError in threads + # where no loop has been set yet (e.g. Celery --pool=threads). + try: + _thread_local.loop = asyncio.get_event_loop() + except RuntimeError: + _thread_local.loop = None _thread_local.pool = ConnectionPool.from_url( _REDIS_URL, db=settings.REDIS_DB, From e134b96333db809e9ce95a9d2e41f20799194085 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 12:10:28 +0800 Subject: [PATCH 112/120] fix(web): ui --- web/src/styles/index.css | 2 +- .../views/UserMemoryDetail/components/CommunityNetwork.tsx | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/src/styles/index.css b/web/src/styles/index.css index 71a6cce4..0f183374 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -397,7 +397,7 @@ body { background-color: #171719; } -.spin.ant-spin-nested-loading .ant-spin-container::after { +.spin .ant-spin-nested-loading .ant-spin-container::after { background: transparent; } .upload-block, diff --git a/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx index 33da0b04..ccfbc14d 100644 --- a/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx +++ b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx @@ -65,8 +65,8 @@ const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => voi }, [id]) if (loading) { - return - + return +
From d4450658a8f7f05606dcaf30151daa04e6b7cae9 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 13:41:46 +0800 Subject: [PATCH 113/120] fix(web): ui --- web/src/styles/index.css | 3 ++- web/src/views/MemoryConversation/index.tsx | 2 +- web/src/views/Prompt/index.tsx | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/web/src/styles/index.css b/web/src/styles/index.css index 0f183374..7e21e1af 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -377,9 +377,10 @@ body { .ant-input-filled, .ant-select-filled:not(.ant-select-customize-input) .ant-select-selector { background-color: #FFFFFF; + border-color: #FFFFFF; } .ant-input-filled:hover, -.ant-select-filled:not(.ant-select-customize-input) .ant-select-selector { +.ant-select-filled:not(.ant-select-disabled):not(.ant-select-customize-input):not(.ant-pagination-size-changer):hover .ant-select-selector { background-color: #FFFFFF; border-color: #171719; } diff --git a/web/src/views/MemoryConversation/index.tsx b/web/src/views/MemoryConversation/index.tsx index c33bd0e5..1e1d0a92 100644 --- a/web/src/views/MemoryConversation/index.tsx +++ b/web/src/views/MemoryConversation/index.tsx @@ -169,8 +169,8 @@ const MemoryConversation: FC = () => { placeholder={t('memoryConversation.searchPlaceholder')} style={{ width: '100%', marginBottom: '16px' }} onChange={setUserId} - variant="borderless" className="rb:bg-white rb:rounded-lg" + variant="filled" /> diff --git a/web/src/views/Prompt/index.tsx b/web/src/views/Prompt/index.tsx index 469b1e39..521971f9 100644 --- a/web/src/views/Prompt/index.tsx +++ b/web/src/views/Prompt/index.tsx @@ -243,6 +243,7 @@ const Prompt: FC = () => { From 2ad25c48d282916e3c855170031999fc844c9f6d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 31 Mar 2026 13:52:41 +0800 Subject: [PATCH 114/120] refactor(memory_agent_service, memory_perceptual_service): Simplify audit logger import and usage - Removed try-except block for importing `audit_logger` and directly imported it. - Removed redundant checks for `audit_logger` being `None` before logging operations. - Added a check in `MemoryPerceptualService` to return `None` if `model_config` or `llm` is `None`. --- api/app/services/memory_agent_service.py | 108 ++++++++---------- api/app/services/memory_perceptual_service.py | 13 ++- 2 files changed, 60 insertions(+), 61 deletions(-) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 289fd74c..c27a75be 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.core.memory.utils.log.audit_logger import audit_logger from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import ( ) from app.services.memory_perceptual_service import MemoryPerceptualService -try: - from app.core.memory.utils.log.audit_logger import audit_logger -except ImportError: - audit_logger = None logger = get_logger(__name__) config_logger = get_config_logger() @@ -68,24 +65,22 @@ class MemoryAgentService: if str(messages) == 'success': logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 - if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=True, - duration=duration, details={"message_length": len(message)}) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=True, + duration=duration, details={"message_length": len(message)}) return context else: logger.warning(f"Write operation failed for group {end_user_id}") # 记录失败的操作 - if audit_logger: - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=f"写入失败: {messages[:100]}" - ) + audit_logger.log_operation( + operation="WRITE", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=f"写入失败: {messages[:100]}" + ) raise ValueError(f"写入失败: {messages}") @@ -338,10 +333,9 @@ class MemoryAgentService: logger.error(error_msg) # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -401,10 +395,10 @@ class MemoryAgentService: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) + + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) async def read_memory( @@ -469,10 +463,9 @@ class MemoryAgentService: logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") # 导入审计日志记录器 - try: - from app.core.memory.utils.log.audit_logger import audit_logger - except ImportError: - audit_logger = None + + + config_load_start = time.time() try: @@ -492,16 +485,15 @@ class MemoryAgentService: logger.error(error_msg) # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=error_msg - ) + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=error_msg + ) raise ValueError(error_msg) @@ -633,15 +625,15 @@ class MemoryAgentService: total_time = time.time() - start_time logger.info( f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=True, - duration=duration - ) + + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=True, + duration=duration + ) return { "answer": summary, @@ -651,16 +643,16 @@ class MemoryAgentService: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=error_msg - ) + + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=error_msg + ) raise ValueError(error_msg) def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 5c838fc0..7cf94a1a 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -244,6 +244,8 @@ class MemoryPerceptualService: file: FileInput ): llm, model_config = self._get_mutlimodal_client(file.type, memory_config) + if model_config is None or llm is None: + return None multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, provider=model_config.provider, @@ -265,15 +267,20 @@ class MemoryPerceptualService: with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') - except FileNotFoundError: - raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) + except FileNotFoundError as e: + business_logger.error(f"Failed to generate perceptual memory: {str(e)}") + return None messages = [ {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]}, {"role": RoleType.USER.value, "content": [ {"type": "text", "text": "Summarize the following file"}, file_message ]} ] - result = await llm.ainvoke(messages) + try: + result = await llm.ainvoke(messages) + except Exception as e: + business_logger.error(f"Failed to generate perceptual memory: {str(e)}") + return None content = result.content final_output = "" if isinstance(content, list): From baf02e4faaa9a5b2886f7eac69182aa6f96dc470 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 15:39:06 +0800 Subject: [PATCH 115/120] fix(web): update i18n --- web/src/i18n/en.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 9b957a84..4e631b26 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1827,6 +1827,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re memoryTipTitle: 'Are you sure you want to enable conversation memory? Conversations will be saved to the memory store.', stopAudioRecorder: 'Stop Recording', startAudioRecorder: 'Start Recording', + citations: 'Citations', }, login: { title: 'Red Bear Memory Science', From 17d3c81c0231c6aab9bdd6757592b711cea12f75 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 19:06:55 +0800 Subject: [PATCH 116/120] fix(web): update i18n --- web/src/i18n/en.ts | 2 +- web/src/i18n/zh.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 4e631b26..c1d11f68 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1508,7 +1508,7 @@ export const en = { EPISODIC_MEMORY: 'Episodic Memory', FORGET_MEMORY: 'Forget Memory', - endUserProfile: 'Profile', + endUserProfile: 'Permanent Memory', editEndUserProfile: 'Edit', other_name: 'Name', position: 'Position', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 862ed5d4..60cb1c98 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1506,7 +1506,7 @@ export const zh = { EPISODIC_MEMORY: '情景记忆', FORGET_MEMORY: '遗忘记忆', - endUserProfile: '核心档案', + endUserProfile: '永久记忆', editEndUserProfile: '编辑', other_name: '名称', position: '职位', From 2dfc3b25d80c18c5306209aa9b7d227d2ee6df27 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 31 Mar 2026 12:26:46 +0800 Subject: [PATCH 117/120] [feat] User list pagination function --- .../memory_dashboard_controller.py | 118 +++++++++--------- api/app/services/memory_dashboard_service.py | 113 ++++++++++++++--- 2 files changed, 160 insertions(+), 71 deletions(-) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index fe4337d1..948154f1 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -1,3 +1,4 @@ +import asyncio from fastapi import APIRouter, Depends, HTTPException, status, Query from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -47,62 +48,62 @@ def get_workspace_total_end_users( @router.get("/end_users", response_model=ApiResponse) async def get_workspace_end_users( + workspace_id: Optional[str] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"), + keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"), + page: int = Query(1, ge=1, description="页码,从1开始"), + pagesize: int = Query(10, ge=1, description="每页数量"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """ - 获取工作空间的宿主列表(高性能优化版本 v2) - - 优化策略: - 1. 批量查询 end_users(一次查询而非循环) - 2. 并发查询所有用户的记忆数量(Neo4j) - 3. RAG 模式使用批量查询(一次 SQL) - 4. 只返回必要字段减少数据传输 - 5. 添加短期缓存减少重复查询 - 6. 并发执行配置查询和记忆数量查询 - - 返回格式: - { - "end_user": {"id": "uuid", "other_name": "名称"}, - "memory_num": {"total": 数量}, - "memory_config": {"memory_config_id": "id", "memory_config_name": "名称"} - } + 获取工作空间的宿主列表(分页查询,支持模糊搜索) + + 返回工作空间下的宿主列表,支持分页查询和模糊搜索。 + 通过 keyword 参数同时模糊匹配 other_name 和 id 字段。 + + Args: + workspace_id: 工作空间ID(可选,默认当前用户工作空间) + keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id) + page: 页码(从1开始,默认1) + pagesize: 每页数量(默认10) + db: 数据库会话 + current_user: 当前用户 + + Returns: + ApiResponse: 包含宿主列表和分页信息 """ - import asyncio - import json - from app.aioRedis import aio_redis_get, aio_redis_set - - workspace_id = current_user.current_workspace_id - - # 尝试从缓存获取(30秒缓存) - cache_key = f"end_users:workspace:{workspace_id}" - try: - cached_data = await aio_redis_get(cache_key) - if cached_data: - api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}") - return success(data=json.loads(cached_data), msg="宿主列表获取成功") - except Exception as e: - api_logger.warning(f"Redis 缓存读取失败: {str(e)}") - + # 如果未提供 workspace_id,使用当前用户的工作空间 + if workspace_id is None: + workspace_id = current_user.current_workspace_id # 获取当前空间类型 current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") - - # 获取 end_users(已优化为批量查询) - end_users = memory_dashboard_service.get_workspace_end_users( + api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表: keyword={keyword}, page={page}, pagesize={pagesize}") + + # 获取分页的 end_users + end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( db=db, workspace_id=workspace_id, - current_user=current_user + current_user=current_user, + page=page, + pagesize=pagesize, + keyword=keyword ) + + end_users = end_users_result.get("items", []) + total = end_users_result.get("total", 0) + if not end_users: - api_logger.info("工作空间下没有宿主") - # 缓存空结果,避免重复查询 - try: - await aio_redis_set(cache_key, json.dumps([]), expire=30) - except Exception as e: - api_logger.warning(f"Redis 缓存写入失败: {str(e)}") - return success(data=[], msg="宿主列表获取成功") - + api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}") + return success(data={ + "items": [], + "page": { + "page": page, + "pagesize": pagesize, + "total": total, + "hasnext": (page * pagesize) < total + } + }, msg="宿主列表获取成功") + end_user_ids = [str(user.id) for user in end_users] # 并发执行两个独立的查询任务 @@ -170,13 +171,13 @@ async def get_workspace_end_users( get_memory_configs(), get_memory_nums() ) - - # 构建结果(优化:使用列表推导式) - result = [] + + # 构建结果列表 + items = [] for end_user in end_users: user_id = str(end_user.id) config_info = memory_configs_map.get(user_id, {}) - result.append({ + items.append({ 'end_user': { 'id': user_id, 'other_name': end_user.other_name @@ -187,12 +188,6 @@ async def get_workspace_end_users( "memory_config_name": config_info.get("memory_config_name") } }) - - # 写入缓存(30秒过期) - try: - await aio_redis_set(cache_key, json.dumps(result), expire=30) - except Exception as e: - api_logger.warning(f"Redis 缓存写入失败: {str(e)}") # 触发社区聚类补全任务(异步,不阻塞接口响应) try: @@ -202,7 +197,18 @@ async def get_workspace_end_users( except Exception as e: api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}") - api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") + # 构建分页响应 + result = { + "items": items, + "page": { + "page": page, + "pagesize": pagesize, + "total": total, + "hasnext": (page * pagesize) < total + } + } + + api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条") return success(data=result, msg="宿主列表获取成功") diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index d0078088..3ab54561 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -1,11 +1,12 @@ from sqlalchemy.orm import Session -from typing import List, Optional +from sqlalchemy import desc, nullslast, or_, and_, cast, String +from typing import List, Optional, Dict, Any import uuid from fastapi import HTTPException from app.models.user_model import User from app.models.app_model import App -from app.models.end_user_model import EndUser +from app.models.end_user_model import EndUser, EndUser as EndUserModel from app.models.memory_increment_model import MemoryIncrement from app.repositories import ( @@ -49,44 +50,40 @@ def get_current_workspace_type( def get_workspace_end_users( - db: Session, - workspace_id: uuid.UUID, + db: Session, + workspace_id: uuid.UUID, current_user: User ) -> List[EndUser]: """获取工作空间的所有宿主(优化版本:减少数据库查询次数) - 返回结果按 created_at 从新到旧排序(NULL 值排在最后) """ business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: + + try: # 查询应用(ORM) apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - + if not apps_orm: business_logger.info("工作空间下没有应用") return [] - + # 提取所有 app_id # app_ids = [app.id for app in apps_orm] - # 批量查询所有 end_users(一次查询而非循环查询) # 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 - from app.models.end_user_model import EndUser as EndUserModel - from sqlalchemy import desc, nullslast end_users_orm = db.query(EndUserModel).filter( EndUserModel.workspace_id == workspace_id ).order_by( nullslast(desc(EndUserModel.created_at)), desc(EndUserModel.id) ).all() - + # 转换为 Pydantic 模型(只在需要时转换) end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm] - + business_logger.info(f"成功获取 {len(end_users)} 个宿主记录") return end_users - + except HTTPException: raise except Exception as e: @@ -94,6 +91,92 @@ def get_workspace_end_users( raise +def get_workspace_end_users_paginated( + db: Session, + workspace_id: uuid.UUID, + current_user: User, + page: int, + pagesize: int, + keyword: Optional[str] = None +) -> Dict[str, Any]: + """获取工作空间的宿主列表(分页版本,支持模糊搜索) + + 返回结果按 created_at 从新到旧排序(NULL 值排在最后) + 支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + current_user: 当前用户 + page: 页码(从1开始) + pagesize: 每页数量 + keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id) + + Returns: + dict: 包含 items(宿主列表)和 total(总记录数)的字典 + """ + business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}") + + try: + # 构建基础查询 + base_query = db.query(EndUserModel).filter( + EndUserModel.workspace_id == workspace_id + ) + + # 构建搜索条件(过滤空字符串和None) + keyword = keyword.strip() if keyword else None + + if keyword: + keyword_pattern = f"%{keyword}%" + # 优先匹配 other_name,如果 other_name 为空则匹配 id + # 使用 OR 条件:匹配 other_name 不为空的数据,或者 other_name 为空但 id 匹配的数据 + base_query = base_query.filter( + or_( + # 情况1:other_name 不为空且匹配 keyword + and_( + EndUserModel.other_name != "", + EndUserModel.other_name.isnot(None), + EndUserModel.other_name.ilike(keyword_pattern) + ), + # 情况2:other_name 为空或 None,但 id 匹配 keyword + and_( + or_( + EndUserModel.other_name == "", + EndUserModel.other_name.is_(None) + ), + cast(EndUserModel.id, String).ilike(keyword_pattern) + ) + ) + ) + business_logger.info(f"应用模糊搜索: keyword={keyword}(优先匹配 other_name,无 other_name 时匹配 id)") + + # 获取总记录数 + total = base_query.count() + + if total == 0: + business_logger.info("工作空间下没有宿主") + return {"items": [], "total": 0} + + # 分页查询 + # 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 + end_users_orm = base_query.order_by( + nullslast(desc(EndUserModel.created_at)), + desc(EndUserModel.id) + ).offset((page - 1) * pagesize).limit(pagesize).all() + + # 转换为 Pydantic 模型 + end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm] + + business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条") + return {"items": end_users, "total": total} + + except HTTPException: + raise + except Exception as e: + business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}") + raise + + def get_workspace_memory_increment( db: Session, workspace_id: uuid.UUID, From ab45b7abacf41dd090bc8aeca595ae47c4853e80 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 31 Mar 2026 16:30:29 +0800 Subject: [PATCH 118/120] [feat] Optimize the performance of the /end_users interface and introduce performance monitoring tools --- .../memory_dashboard_controller.py | 34 +++++++---------- .../repositories/memory_config_repository.py | 9 +++++ api/app/services/memory_storage_service.py | 31 ++++++++++++++++ api/app/utils/performance_timer.py | 37 +++++++++++++++++++ 4 files changed, 90 insertions(+), 21 deletions(-) create mode 100644 api/app/utils/performance_timer.py diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 948154f1..260ea670 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -77,7 +77,7 @@ async def get_workspace_end_users( workspace_id = current_user.current_workspace_id # 获取当前空间类型 current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表: keyword={keyword}, page={page}, pagesize={pagesize}") + api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}") # 获取分页的 end_users end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( @@ -105,7 +105,7 @@ async def get_workspace_end_users( }, msg="宿主列表获取成功") end_user_ids = [str(user.id) for user in end_users] - + # 并发执行两个独立的查询任务 async def get_memory_configs(): """获取记忆配置(在线程池中执行同步查询)""" @@ -117,7 +117,7 @@ async def get_workspace_end_users( except Exception as e: api_logger.error(f"批量获取记忆配置失败: {str(e)}") return {} - + async def get_memory_nums(): """获取记忆数量""" if current_workspace_type == "rag": @@ -131,26 +131,18 @@ async def get_workspace_end_users( except Exception as e: api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}") return {uid: {"total": 0} for uid in end_user_ids} - + elif current_workspace_type == "neo4j": - # Neo4j 模式:并发查询(带并发限制) - # 使用信号量限制并发数,避免大量用户时压垮 Neo4j - MAX_CONCURRENT_QUERIES = 10 - semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES) - - async def get_neo4j_memory_num(end_user_id: str): - async with semaphore: - try: - return await memory_storage_service.search_all(end_user_id) - except Exception as e: - api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}") - return {"total": 0} - - memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids]) - return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))} - + # Neo4j 模式:批量查询(简化版本,只返回total) + try: + batch_result = await memory_storage_service.search_all_batch(end_user_ids) + return {uid: {"total": count} for uid, count in batch_result.items()} + except Exception as e: + api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}") + return {uid: {"total": 0} for uid in end_user_ids} + return {uid: {"total": 0} for uid in end_user_ids} - + # 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据 try: from app.celery_app import celery_app as _celery_app diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index e64d19a3..3139b851 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -78,6 +78,15 @@ class MemoryConfigRepository: OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count """ + # 批量查询多个用户的记忆数量(简化版本,只返回total) + SEARCH_FOR_ALL_BATCH = """ + MATCH (n) WHERE n.end_user_id IN $end_user_ids + RETURN + n.end_user_id as user_id, + count(n) as total + ORDER BY user_id + """ + # Extracted entity details within group/app/user SEARCH_FOR_DETIALS = """ MATCH (n:ExtractedEntity) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 58f3e8bd..b3a66734 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -695,6 +695,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] return result +async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]: + """批量查询多个用户的记忆数量(简化版本,只返回total) + + Args: + end_user_ids: 用户ID列表 + + Returns: + Dict[str, int]: 以user_id为key的记忆数量字典 + 格式: {"user_id": total_count} + """ + if not end_user_ids: + return {} + + result = await _neo4j_connector.execute_query( + MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, + end_user_ids=end_user_ids, + ) + + # 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回 + data = {} + for row in result: + data[row["user_id"]] = row["total"] + + # 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值 + for user_id in end_user_ids: + if user_id not in data: + data[user_id] = 0 + + return data + + async def analytics_hot_memory_tags( db: Session, current_user: User, diff --git a/api/app/utils/performance_timer.py b/api/app/utils/performance_timer.py new file mode 100644 index 00000000..6b0ec5d6 --- /dev/null +++ b/api/app/utils/performance_timer.py @@ -0,0 +1,37 @@ +""" +性能监控工具模块 + +提供代码块执行时间统计功能,用于接口性能分析。 +如需再次启用性能监控,只需在 controller 中导入 from app.utils.performance_timer import timer 并添加 with timer(...) 包裹需要监控的代码块即可 +""" + +import time +from contextlib import contextmanager +from app.core.logging_config import get_api_logger + +# 获取API专用日志器 +api_logger = get_api_logger() + + +@contextmanager +def timer(label: str, user_count: int = 0): + """上下文管理器:用于测量代码块执行时间 + + Args: + label: 统计标签,用于标识被测量的代码块 + user_count: 用户数,可选参数,用于记录处理的用户数量 + + Usage: + with timer("获取用户列表"): + users = get_users() + + with timer("批量处理", user_count=len(user_ids)): + process_users(user_ids) + """ + start = time.perf_counter() + try: + yield + finally: + elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒 + extra_info = f", 用户数: {user_count}" if user_count > 0 else "" + api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}") From b5c5863b39cb374bed49fcbf46b2d3bc3e8b443b Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 31 Mar 2026 15:00:30 +0800 Subject: [PATCH 119/120] [feat] RAG storage adjustment returns data structure --- api/app/services/memory_dashboard_service.py | 23 +++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index d0078088..a52fb18f 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -638,7 +638,24 @@ def get_rag_content( business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}") continue - # 4. 返回结果 + # 4. 将所有 page_content 拼接后按角色分割为对话列表 + merged_text = "\n".join(page_contents) + conversations = [] + if merged_text.strip(): + import re + # 在任意位置匹配 "user:" 或 "assistant:",不限于行首 + parts = re.split(r'(user|assistant):', merged_text) + # parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...] + i = 1 + while i < len(parts) - 1: + role = parts[i].strip() + content = parts[i + 1].strip() + # 将 content 中的 \n 还原为真实换行 + content = content.replace("\\n", "\n") + if role in ("user", "assistant") and content: + conversations.append({"role": role, "content": content}) + i += 2 + result = { "page": { "page": page, @@ -646,10 +663,10 @@ def get_rag_content( "total": global_total, "hasnext": offset_end < global_total, }, - "items": page_contents + "items": conversations } - business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条") + business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)} 条对话") return result except Exception as e: From 6d6338eb06388aedddbf9951f0c2dc40dd8f0596 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 1 Apr 2026 10:36:29 +0800 Subject: [PATCH 120/120] [changes] Modify the data format and improve the query logic. --- .../memory_dashboard_controller.py | 3 ++- api/app/services/memory_dashboard_service.py | 19 ++++++------------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 260ea670..bedee987 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -1,4 +1,5 @@ import asyncio +import uuid from fastapi import APIRouter, Depends, HTTPException, status, Query from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -48,7 +49,7 @@ def get_workspace_total_end_users( @router.get("/end_users", response_model=ApiResponse) async def get_workspace_end_users( - workspace_id: Optional[str] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"), + workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"), keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"), page: int = Query(1, ge=1, description="页码,从1开始"), pagesize: int = Query(10, ge=1, description="每页数量"), diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 3ab54561..9fad5dfe 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -128,27 +128,20 @@ def get_workspace_end_users_paginated( if keyword: keyword_pattern = f"%{keyword}%" - # 优先匹配 other_name,如果 other_name 为空则匹配 id - # 使用 OR 条件:匹配 other_name 不为空的数据,或者 other_name 为空但 id 匹配的数据 + # other_name 匹配始终生效;id 匹配仅对 other_name 为空的记录生效 base_query = base_query.filter( or_( - # 情况1:other_name 不为空且匹配 keyword - and_( - EndUserModel.other_name != "", - EndUserModel.other_name.isnot(None), - EndUserModel.other_name.ilike(keyword_pattern) - ), - # 情况2:other_name 为空或 None,但 id 匹配 keyword + EndUserModel.other_name.ilike(keyword_pattern), and_( or_( + EndUserModel.other_name.is_(None), EndUserModel.other_name == "", - EndUserModel.other_name.is_(None) ), - cast(EndUserModel.id, String).ilike(keyword_pattern) - ) + cast(EndUserModel.id, String).ilike(keyword_pattern), + ), ) ) - business_logger.info(f"应用模糊搜索: keyword={keyword}(优先匹配 other_name,无 other_name 时匹配 id)") + business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name;other_name 为空时匹配 id)") # 获取总记录数 total = base_query.count()