diff --git a/api/app/cache/__init__.py b/api/app/cache/__init__.py new file mode 100644 index 00000000..a79d4cb2 --- /dev/null +++ b/api/app/cache/__init__.py @@ -0,0 +1,11 @@ +""" +Cache 缓存模块 + +提供各种缓存功能的统一入口 +""" +from .memory import EmotionMemoryCache, ImplicitMemoryCache + +__all__ = [ + "EmotionMemoryCache", + "ImplicitMemoryCache", +] diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py new file mode 100644 index 00000000..4ada3153 --- /dev/null +++ b/api/app/cache/memory/__init__.py @@ -0,0 +1,12 @@ +""" +Memory 缓存模块 + +提供记忆系统相关的缓存功能 +""" +from .emotion_memory import EmotionMemoryCache +from .implicit_memory import ImplicitMemoryCache + +__all__ = [ + "EmotionMemoryCache", + "ImplicitMemoryCache", +] diff --git a/api/app/cache/memory/emotion_memory.py b/api/app/cache/memory/emotion_memory.py new file mode 100644 index 00000000..45ea90de --- /dev/null +++ b/api/app/cache/memory/emotion_memory.py @@ -0,0 +1,134 @@ +""" +Emotion Suggestions Cache + +情绪个性化建议缓存模块 +用于缓存用户的情绪个性化建议数据 +""" +import json +import logging +from typing import Optional, Dict, Any +from datetime import datetime + +from app.aioRedis import aio_redis + +logger = logging.getLogger(__name__) + + +class EmotionMemoryCache: + """情绪建议缓存类""" + + # Key 前缀 + PREFIX = "cache:memory:emotion_memory" + + @classmethod + def _get_key(cls, *parts: str) -> str: + """生成 Redis key + + Args: + *parts: key 的各个部分 + + Returns: + 完整的 Redis key + """ + return ":".join([cls.PREFIX] + list(parts)) + + @classmethod + async def set_emotion_suggestions( + cls, + user_id: str, + suggestions_data: Dict[str, Any], + expire: int = 86400 + ) -> bool: + """设置用户情绪建议缓存 + + Args: + user_id: 用户ID(end_user_id) + suggestions_data: 建议数据字典,包含: + - health_summary: 健康状态摘要 + - suggestions: 建议列表 + - generated_at: 生成时间(可选) + expire: 过期时间(秒),默认24小时(86400秒) + + Returns: + 是否设置成功 + """ + try: + key = cls._get_key("suggestions", user_id) + + # 添加生成时间戳 + if "generated_at" not in suggestions_data: + suggestions_data["generated_at"] = datetime.now().isoformat() + + # 添加缓存标记 + suggestions_data["cached"] = True + + value = json.dumps(suggestions_data, ensure_ascii=False) + await aio_redis.set(key, value, ex=expire) + logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]: + """获取用户情绪建议缓存 + + Args: + user_id: 用户ID(end_user_id) + + Returns: + 建议数据字典,如果不存在或已过期返回 None + """ + try: + key = cls._get_key("suggestions", user_id) + value = await aio_redis.get(key) + + if value: + data = json.loads(value) + logger.info(f"成功获取情绪建议缓存: {key}") + return data + + logger.info(f"情绪建议缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True) + return None + + @classmethod + async def delete_emotion_suggestions(cls, user_id: str) -> bool: + """删除用户情绪建议缓存 + + Args: + user_id: 用户ID(end_user_id) + + Returns: + 是否删除成功 + """ + try: + key = cls._get_key("suggestions", user_id) + result = await aio_redis.delete(key) + logger.info(f"删除情绪建议缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_suggestions_ttl(cls, user_id: str) -> int: + """获取情绪建议缓存的剩余过期时间 + + Args: + user_id: 用户ID(end_user_id) + + Returns: + 剩余秒数,-1表示永不过期,-2表示key不存在 + """ + try: + key = cls._get_key("suggestions", user_id) + ttl = await aio_redis.ttl(key) + logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒") + return ttl + except Exception as e: + logger.error(f"获取情绪建议缓存TTL失败: {e}") + return -2 diff --git a/api/app/cache/memory/implicit_memory.py b/api/app/cache/memory/implicit_memory.py new file mode 100644 index 00000000..21f08e9a --- /dev/null +++ b/api/app/cache/memory/implicit_memory.py @@ -0,0 +1,136 @@ +""" +Implicit Memory Profile Cache + +隐式记忆用户画像缓存模块 +用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯) +""" +import json +import logging +from typing import Optional, Dict, Any +from datetime import datetime + +from app.aioRedis import aio_redis + +logger = logging.getLogger(__name__) + + +class ImplicitMemoryCache: + """隐式记忆用户画像缓存类""" + + # Key 前缀 + PREFIX = "cache:memory:implicit_memory" + + @classmethod + def _get_key(cls, *parts: str) -> str: + """生成 Redis key + + Args: + *parts: key 的各个部分 + + Returns: + 完整的 Redis key + """ + return ":".join([cls.PREFIX] + list(parts)) + + @classmethod + async def set_user_profile( + cls, + user_id: str, + profile_data: Dict[str, Any], + expire: int = 86400 + ) -> bool: + """设置用户完整画像缓存 + + Args: + user_id: 用户ID(end_user_id) + profile_data: 画像数据字典,包含: + - preferences: 偏好标签列表 + - portrait: 四维画像对象 + - interest_areas: 兴趣领域分布对象 + - habits: 行为习惯列表 + - generated_at: 生成时间(可选) + expire: 过期时间(秒),默认24小时(86400秒) + + Returns: + 是否设置成功 + """ + try: + key = cls._get_key("profile", user_id) + + # 添加生成时间戳 + if "generated_at" not in profile_data: + profile_data["generated_at"] = datetime.now().isoformat() + + # 添加缓存标记 + profile_data["cached"] = True + + value = json.dumps(profile_data, ensure_ascii=False) + await aio_redis.set(key, value, ex=expire) + logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置用户画像缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]: + """获取用户完整画像缓存 + + Args: + user_id: 用户ID(end_user_id) + + Returns: + 画像数据字典,如果不存在或已过期返回 None + """ + try: + key = cls._get_key("profile", user_id) + value = await aio_redis.get(key) + + if value: + data = json.loads(value) + logger.info(f"成功获取用户画像缓存: {key}") + return data + + logger.info(f"用户画像缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取用户画像缓存失败: {e}", exc_info=True) + return None + + @classmethod + async def delete_user_profile(cls, user_id: str) -> bool: + """删除用户完整画像缓存 + + Args: + user_id: 用户ID(end_user_id) + + Returns: + 是否删除成功 + """ + try: + key = cls._get_key("profile", user_id) + result = await aio_redis.delete(key) + logger.info(f"删除用户画像缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除用户画像缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_profile_ttl(cls, user_id: str) -> int: + """获取用户画像缓存的剩余过期时间 + + Args: + user_id: 用户ID(end_user_id) + + Returns: + 剩余秒数,-1表示永不过期,-2表示key不存在 + """ + try: + key = cls._get_key("profile", user_id) + ttl = await aio_redis.ttl(key) + logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒") + return ttl + except Exception as e: + logger.error(f"获取用户画像缓存TTL失败: {e}") + return -2 diff --git a/api/app/celery_worker.py b/api/app/celery_worker.py index baecdb3d..7d3ee686 100644 --- a/api/app/celery_worker.py +++ b/api/app/celery_worker.py @@ -3,6 +3,12 @@ Celery Worker 入口点 用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info """ from app.celery_app import celery_app +from app.core.logging_config import LoggingConfig, get_logger + +# Initialize logging system for Celery worker +LoggingConfig.setup_logging() +logger = get_logger(__name__) +logger.info("Celery worker logging initialized") # 导入任务模块以注册任务 import app.tasks diff --git a/api/app/controllers/emotion_controller.py b/api/app/controllers/emotion_controller.py index c92c04d5..b5cd7250 100644 --- a/api/app/controllers/emotion_controller.py +++ b/api/app/controllers/emotion_controller.py @@ -59,7 +59,7 @@ async def get_emotion_tags( "limit": request.limit } ) - + # 调用服务层 data = await emotion_service.get_emotion_tags( end_user_id=request.group_id, @@ -68,7 +68,7 @@ async def get_emotion_tags( end_date=request.end_date, limit=request.limit ) - + api_logger.info( "情绪标签统计获取成功", extra={ @@ -77,9 +77,9 @@ async def get_emotion_tags( "tags_count": len(data.get("tags", [])) } ) - + return success(data=data, msg="情绪标签获取成功") - + except Exception as e: api_logger.error( f"获取情绪标签统计失败: {str(e)}", @@ -108,14 +108,14 @@ async def get_emotion_wordcloud( "limit": request.limit } ) - + # 调用服务层 data = await emotion_service.get_emotion_wordcloud( end_user_id=request.group_id, emotion_type=request.emotion_type, limit=request.limit ) - + api_logger.info( "情绪词云数据获取成功", extra={ @@ -123,9 +123,9 @@ async def get_emotion_wordcloud( "total_keywords": data.get("total_keywords", 0) } ) - + return success(data=data, msg="情绪词云获取成功") - + except Exception as e: api_logger.error( f"获取情绪词云数据失败: {str(e)}", @@ -152,7 +152,7 @@ async def get_emotion_health( status_code=status.HTTP_400_BAD_REQUEST, detail="时间范围参数无效,必须是 7d、30d 或 90d" ) - + api_logger.info( f"用户 {current_user.username} 请求获取情绪健康指数", extra={ @@ -160,13 +160,13 @@ async def get_emotion_health( "time_range": request.time_range } ) - + # 调用服务层 data = await emotion_service.calculate_emotion_health_index( end_user_id=request.group_id, time_range=request.time_range ) - + api_logger.info( "情绪健康指数获取成功", extra={ @@ -175,9 +175,9 @@ async def get_emotion_health( "level": data.get("level", "未知") } ) - + return success(data=data, msg="情绪健康指数获取成功") - + except HTTPException: raise except Exception as e: @@ -200,12 +200,12 @@ async def get_emotion_suggestions( current_user: User = Depends(get_current_user), ): """获取个性化情绪建议(从缓存读取) - + Args: request: 包含 group_id 和可选的 config_id db: 数据库会话 current_user: 当前用户 - + Returns: 缓存的个性化情绪建议响应 """ @@ -217,13 +217,13 @@ async def get_emotion_suggestions( "config_id": request.config_id } ) - + # 从缓存获取建议 data = await emotion_service.get_cached_suggestions( end_user_id=request.group_id, db=db ) - + if data is None: # 缓存不存在或已过期 api_logger.info( @@ -231,11 +231,11 @@ async def get_emotion_suggestions( extra={"group_id": request.group_id} ) return fail( - BizCode.RESOURCE_NOT_FOUND, + BizCode.NOT_FOUND, "建议缓存不存在或已过期,请调用 /generate_suggestions 接口生成新建议", - None + "" ) - + api_logger.info( "个性化建议获取成功(缓存)", extra={ @@ -243,9 +243,9 @@ async def get_emotion_suggestions( "suggestions_count": len(data.get("suggestions", [])) } ) - + return success(data=data, msg="个性化建议获取成功(缓存)") - + except Exception as e: api_logger.error( f"获取个性化建议失败: {str(e)}", @@ -265,76 +265,51 @@ async def generate_emotion_suggestions( current_user: User = Depends(get_current_user), ): """生成个性化情绪建议(调用LLM并缓存) - + Args: - request: 包含 group_id、可选的 config_id 和 force_refresh + request: 包含 end_user_id db: 数据库会话 current_user: 当前用户 - + Returns: 新生成的个性化情绪建议响应 """ try: - # 验证 config_id(如果提供) - # 获取终端用户关联的配置 - config_id = request.config_id - if config_id is None: - # 如果没有提供 config_id,尝试获取用户关联的配置 - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - connected_config = get_end_user_connected_config(request.group_id, db) - config_id = connected_config.get("memory_config_id") - except ValueError as e: - return fail(BizCode.INVALID_PARAMETER, "无法获取用户关联的配置", str(e)) - else: - # 如果提供了 config_id,验证其有效性 - from app.services.memory_config_service import MemoryConfigService - try: - config_service = MemoryConfigService(db) - config = config_service.get_config_by_id(config_id) - if not config: - return fail(BizCode.INVALID_PARAMETER, "配置ID无效", f"配置 {config_id} 不存在") - except Exception as e: - return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e)) - api_logger.info( f"用户 {current_user.username} 请求生成个性化情绪建议", extra={ - "group_id": request.group_id, - "config_id": config_id + "end_user_id": request.end_user_id } ) - + # 调用服务层生成建议 data = await emotion_service.generate_emotion_suggestions( - end_user_id=request.group_id, + end_user_id=request.end_user_id, db=db ) - + # 保存到缓存 await emotion_service.save_suggestions_cache( - end_user_id=request.group_id, + end_user_id=request.end_user_id, suggestions_data=data, db=db, expires_hours=24 ) - + api_logger.info( "个性化建议生成成功", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "suggestions_count": len(data.get("suggestions", [])) } ) - + return success(data=data, msg="个性化建议生成成功") - + except Exception as e: api_logger.error( f"生成个性化建议失败: {str(e)}", - extra={"group_id": request.group_id}, + extra={"end_user_id": request.end_user_id}, exc_info=True ) raise HTTPException( diff --git a/api/app/controllers/implicit_memory_controller.py b/api/app/controllers/implicit_memory_controller.py index eb7037ff..62d1e428 100644 --- a/api/app/controllers/implicit_memory_controller.py +++ b/api/app/controllers/implicit_memory_controller.py @@ -161,9 +161,9 @@ async def get_preference_tags( if cached_profile is None: api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") return fail( - BizCode.RESOURCE_NOT_FOUND, + BizCode.NOT_FOUND, "画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像", - None + "" ) # Extract preferences from cache @@ -232,9 +232,9 @@ async def get_dimension_portrait( if cached_profile is None: api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") return fail( - BizCode.RESOURCE_NOT_FOUND, + BizCode.NOT_FOUND, "画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像", - None + "" ) # Extract portrait from cache @@ -280,9 +280,9 @@ async def get_interest_area_distribution( if cached_profile is None: api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") return fail( - BizCode.RESOURCE_NOT_FOUND, + BizCode.NOT_FOUND, "画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像", - None + "" ) # Extract interest areas from cache @@ -332,9 +332,9 @@ async def get_behavior_habits( if cached_profile is None: api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") return fail( - BizCode.RESOURCE_NOT_FOUND, + BizCode.NOT_FOUND, "画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像", - None + "" ) # Extract habits from cache diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index b0287d80..9be6e035 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,10 +1,11 @@ import asyncio import time +import uuid from app.core.logging_config import get_api_logger from app.core.memory.storage_services.reflection_engine.self_reflexion import ( ReflectionConfig, - ReflectionEngine, + ReflectionEngine, ReflectionRange, ReflectionBaseline, ) from app.core.response_utils import success from app.db import get_db @@ -39,9 +40,6 @@ async def save_reflection_config( db: Session = Depends(get_db), ) -> dict: """Save reflection configuration to data_comfig table""" - - - try: config_id = request.config_id if not config_id: @@ -52,51 +50,30 @@ async def save_reflection_config( api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") - update_params = { - "enable_self_reflexion": request.reflection_enabled, - "iteration_period": request.reflection_period_in_hours, - "reflexion_range": request.reflexion_range, - "baseline": request.baseline, - "reflection_model_id": request.reflection_model_id, - "memory_verify": request.memory_verify, - "quality_assessment": request.quality_assessment, - } + data_config = DataConfigRepository.update_reflection_config( + db, + config_id=config_id, + enable_self_reflexion=request.reflection_enabled, + iteration_period=request.reflection_period_in_hours, + reflexion_range=request.reflexion_range, + baseline=request.baseline, + reflection_model_id=request.reflection_model_id, + memory_verify=request.memory_verify, + quality_assessment=request.quality_assessment + ) - - - query, params = DataConfigRepository.build_update_reflection(config_id, **update_params) - - result = db.execute(text(query), params) - if result.rowcount == 0: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"未找到config_id为 {config_id} 的配置" - ) - db.commit() - - # 查询更新后的配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - - if not result: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"更新后未找到config_id为 {config_id} 的配置" - ) - - api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}") + db.refresh(data_config) reflection_result={ - "config_id": result.config_id, - "enable_self_reflexion": result.enable_self_reflexion, - "iteration_period": result.iteration_period, - "reflexion_range": result.reflexion_range, - "baseline": result.baseline, - "reflection_model_id": result.reflection_model_id, - "memory_verify": result.memory_verify, - "quality_assessment": result.quality_assessment, - "user_id": result.user_id} + "config_id": data_config.config_id, + "enable_self_reflexion": data_config.enable_self_reflexion, + "iteration_period": data_config.iteration_period, + "reflexion_range": data_config.reflexion_range, + "baseline": data_config.baseline, + "reflection_model_id": data_config.reflection_model_id, + "memory_verify": data_config.memory_verify, + "quality_assessment": data_config.quality_assessment} return success(data=reflection_result, msg="反思配置成功") @@ -116,9 +93,8 @@ async def save_reflection_config( ) -@router.post("/reflection") +@router.get("/reflection") async def start_workspace_reflection( - config_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -143,11 +119,20 @@ async def start_workspace_reflection( end_users = data['end_users'] for base, config, user in zip(releases, data_configs, end_users): - if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 安全地转换为整数,处理空字符串和None的情况 + print(base['config']) + try: + base_config = int(base['config']) if base['config'] else 0 + config_id = int(config['config_id']) if config['config_id'] else 0 + except (ValueError, TypeError): + api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}") + continue + + if base_config == config_id and base['app_id'] == user['app_id']: # 调用反思服务 api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") - reflection_result = await reflection_service.start_reflection_from_data( + reflection_result = await reflection_service.start_text_reflection( config_data=config, end_user_id=user['id'] ) @@ -178,17 +163,7 @@ async def start_reflection_configs( """通过config_id查询data_config表中的反思配置信息""" try: api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") - - # 使用DataConfigRepository查询反思配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - - if not result: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"未找到config_id为 {config_id} 的配置" - ) - + result = DataConfigRepository.query_reflection_config_by_id(db, config_id) # 构建返回数据 reflection_config = { "config_id": result.config_id, @@ -198,8 +173,7 @@ async def start_reflection_configs( "baseline": result.baseline, "reflection_model_id": result.reflection_model_id, "memory_verify": result.memory_verify, - "quality_assessment": result.quality_assessment, - "user_id": result.user_id + "quality_assessment": result.quality_assessment } api_logger.info(f"成功查询反思配置,config_id: {config_id}") return success(data=reflection_config, msg="反思配置查询成功") @@ -227,9 +201,7 @@ async def reflection_run( api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") # 使用DataConfigRepository查询反思配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - + result = DataConfigRepository.query_reflection_config_by_id(db, config_id) if not result: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -242,7 +214,7 @@ async def reflection_run( model_id = result.reflection_model_id if model_id: try: - ModelConfigService.get_model_by_id(db=db, model_id=model_id) + ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id)) api_logger.info(f"模型ID验证成功: {model_id}") except Exception as e: api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}") @@ -252,8 +224,8 @@ async def reflection_run( config = ReflectionConfig( enabled=result.enable_self_reflexion, iteration_period=result.iteration_period, - reflexion_range=result.reflexion_range, - baseline=result.baseline, + reflexion_range=ReflectionRange(result.reflexion_range), + baseline=ReflectionBaseline(result.baseline), output_example='', memory_verify=result.memory_verify, quality_assessment=result.quality_assessment, diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index c58ecd6d..63d9078a 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,10 +1,8 @@ import os -import uuid from typing import Optional from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger -from app.core.memory.utils.self_reflexion_utils import self_reflexion from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user @@ -458,18 +456,3 @@ async def get_recent_activity_stats_api( api_logger.error(f"Recent activity stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - - - -@router.get("/self_reflexion") -async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: - """ - 自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。 - - Args: - None - Returns: - 自我反思结果。 - """ - return await self_reflexion(host_id) - diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 04da05df..17ad70a7 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -8,9 +8,10 @@ from sqlalchemy.orm import Session from app.core.logging_config import get_business_logger from app.core.response_utils import success -from app.db import get_db +from app.db import get_db, get_db_read from app.dependencies import get_share_user_id, ShareTokenData from app.repositories import knowledge_repository +from app.repositories.workflow_repository import WorkflowConfigRepository from app.schemas import release_share_schema, conversation_schema from app.schemas.response_schema import PageData, PageMeta from app.services import workspace_service @@ -19,7 +20,8 @@ from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService from app.services.app_chat_service import AppChatService, get_app_chat_service -from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release +from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \ + agent_config_4_app_release, multi_agent_config_4_app_release router = APIRouter(prefix="/public/share", tags=["Public Share"]) logger = get_business_logger() @@ -65,10 +67,10 @@ def get_or_generate_user_id(payload_user_id: str, request: Request) -> str: summary="获取访问 token" ) def get_access_token( - share_token: str, - payload: release_share_schema.TokenRequest, - request: Request, - db: Session = Depends(get_db), + share_token: str, + payload: release_share_schema.TokenRequest, + request: Request, + db: Session = Depends(get_db), ): """获取访问 token @@ -113,9 +115,9 @@ def get_access_token( response_model=None ) def get_shared_release( - password: str = Query(None, description="访问密码(如果需要)"), - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), + password: str = Query(None, description="访问密码(如果需要)"), + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), ): """获取公开分享的发布版本信息 @@ -137,9 +139,9 @@ def get_shared_release( summary="验证访问密码" ) def verify_password( - payload: release_share_schema.PasswordVerifyRequest, - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), + payload: release_share_schema.PasswordVerifyRequest, + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), ): """验证分享的访问密码 @@ -159,11 +161,11 @@ def verify_password( summary="获取嵌入代码" ) def get_embed_code( - width: str = Query("100%", description="iframe 宽度"), - height: str = Query("600px", description="iframe 高度"), - request: Request = None, - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), + width: str = Query("100%", description="iframe 宽度"), + height: str = Query("600px", description="iframe 高度"), + request: Request = None, + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), ): """获取嵌入代码 @@ -183,7 +185,6 @@ def get_embed_code( return success(data=embed_code) - # ---------- 会话管理接口 ---------- @router.get( @@ -191,11 +192,11 @@ def get_embed_code( summary="获取会话列表" ) def list_conversations( - password: str = Query(None, description="访问密码"), - page: int = Query(1, ge=1), - pagesize: int = Query(20, ge=1, le=100), - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), + password: str = Query(None, description="访问密码"), + page: int = Query(1, ge=1), + pagesize: int = Query(20, ge=1, le=100), + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), ): """获取分享应用的会话列表 @@ -209,9 +210,9 @@ def list_conversations( from app.repositories.end_user_repository import EndUserRepository end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( - app_id=share.app_id, - other_id=other_id - ) + app_id=share.app_id, + other_id=other_id + ) logger.debug(new_end_user.id) service = SharedChatService(db) conversations, total = service.list_conversations( @@ -233,10 +234,10 @@ def list_conversations( summary="获取会话详情(含消息)" ) def get_conversation( - conversation_id: uuid.UUID, - password: str = Query(None, description="访问密码"), - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), + conversation_id: uuid.UUID, + password: str = Query(None, description="访问密码"), + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), ): """获取会话详情和消息历史""" chat_service = SharedChatService(db) @@ -266,10 +267,10 @@ def get_conversation( summary="发送消息(支持流式和非流式)" ) async def chat( - payload: conversation_schema.ChatRequest, - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), - app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None, + payload: conversation_schema.ChatRequest, + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), + app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None, ): """发送消息并获取回复 @@ -313,7 +314,7 @@ async def chat( ) end_user_id = str(new_end_user.id) - appid=share.app_id + appid = share.app_id """获取存储类型和工作空间的ID""" # 直接通过 SQLAlchemy 查询 app @@ -425,16 +426,16 @@ async def chat( # ) async def event_generator(): async for event in app_chat_service.agnet_chat_stream( - message=payload.message, - conversation_id=conversation.id, # 使用已创建的会话 ID - user_id= str(new_end_user.id), # 转换为字符串 - variables=payload.variables, - web_search=payload.web_search, - config=agent_config, - memory=payload.memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - workspace_id=workspace_id + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=str(new_end_user.id), # 转换为字符串 + variables=payload.variables, + web_search=payload.web_search, + config=agent_config, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ): yield event @@ -481,15 +482,15 @@ async def chat( async def event_generator(): async for event in app_chat_service.multi_agent_chat_stream( - message=payload.message, - conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=str(new_end_user.id), # 转换为字符串 - variables=payload.variables, - config=config, - web_search=payload.web_search, - memory=payload.memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=str(new_end_user.id), # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event @@ -561,24 +562,27 @@ async def chat( # return success(data=conversation_schema.ChatResponse(**result)) elif app_type == AppType.WORKFLOW: - config = workflow_config_4_app_release(release) + if not config.id: + with get_db_read() as db: + source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id) + config.id = source_config.id + config.id = uuid.UUID(config.id) if payload.stream: async def event_generator(): - async for event in app_chat_service.workflow_chat_stream( - - message=payload.message, - conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=end_user_id, # 转换为字符串 - variables=payload.variables, - config=config, - web_search=payload.web_search, - memory=payload.memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - app_id=release.app_id, - workspace_id=workspace_id + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=end_user_id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + app_id=release.app_id, + workspace_id=workspace_id, + release_id=release.id ): event_type = event.get("event", "message") event_data = event.get("data", {}) @@ -610,7 +614,8 @@ async def chat( storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, app_id=release.app_id, - workspace_id=workspace_id + workspace_id=workspace_id, + release_id=release.id ) logger.debug( "工作流试运行返回结果", diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 583b4700..677e1623 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -242,8 +242,9 @@ async def chat( memory=payload.memory, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, - app_id=app.app_id, - workspace_id=workspace_id + app_id=app.id, + workspace_id=workspace_id, + release_id=app.current_release.id, ): event_type = event.get("event", "message") event_data = event.get("data", {}) @@ -274,8 +275,9 @@ async def chat( memory=payload.memory, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, - app_id=app.app_id, - workspace_id=workspace_id + app_id=app.id, + workspace_id=workspace_id, + release_id=app.current_release.id ) logger.debug( "工作流试运行返回结果", diff --git a/api/app/core/config.py b/api/app/core/config.py index 5f4f91c4..9600b551 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -38,6 +38,7 @@ class Settings: REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") + # ElasticSearch configuration ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") @@ -146,6 +147,7 @@ class Settings: # Celery configuration (internal) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) + REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) diff --git a/api/app/core/memory/agent/__init__.py b/api/app/core/memory/agent/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/api/app/core/memory/agent/langgraph_graph/__init__.py b/api/app/core/memory/agent/langgraph_graph/__init__.py deleted file mode 100644 index a0596e38..00000000 --- a/api/app/core/memory/agent/langgraph_graph/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -LangGraph Graph package for memory agent. - -This package provides the LangGraph workflow orchestrator with modular -node implementations, routing logic, and state management. - -Package structure: -- read_graph: Main graph factory for read operations -- write_graph: Main graph factory for write operations -- nodes: LangGraph node implementations -- routing: State routing logic -- state: State management utilities -""" -from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph - -__all__ = ['make_read_graph'] \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py b/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py index 4e808919..231a167c 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py @@ -4,7 +4,7 @@ LangGraph node implementations. This module contains custom node implementations for the LangGraph workflow. """ -from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode -from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message - -__all__ = ["ToolExecutionNode", "create_input_message"] +# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode +# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message +# +# __all__ = ["ToolExecutionNode", "create_input_message"] diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py new file mode 100644 index 00000000..6595a2ce --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py @@ -0,0 +1,16 @@ +from app.core.memory.agent.utils.llm_tools import ReadState, WriteState + + +def content_input_node(state: ReadState) -> ReadState: + """开始节点 - 提取内容并保持状态信息""" + + content = state['messages'][0].content if state.get('messages') else '' + # 返回内容并保持所有状态信息 + return {"data": content} + +def content_input_write(state: WriteState) -> WriteState: + """开始节点 - 提取内容并保持状态信息""" + + content = state['messages'][0].content if state.get('messages') else '' + # 返回内容并保持所有状态信息 + return {"data": content} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py deleted file mode 100644 index 3eed497f..00000000 --- a/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Input node for LangGraph workflow entry point. - -This module provides the create_input_message function which processes initial -user input with multimodal support and creates the first tool call message. -""" - -import logging -import re -import uuid -from datetime import datetime -from typing import Any, Dict - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage - -logger = logging.getLogger(__name__) - - -async def create_input_message( - state: Dict[str, Any], - tool_name: str, - session_id: str, - search_switch: str, - apply_id: str, - group_id: str, - multimodal_processor: MultimodalProcessor, - memory_config: MemoryConfig, -) -> Dict[str, Any]: - """ - Create initial tool call message from user input. - - This function: - 1. Extracts the last message content from state - 2. Processes multimodal inputs (images/audio) using the multimodal processor - 3. Generates a unique message ID - 4. Extracts namespace from session_id - 5. Handles verified_data extraction for backward compatibility - 6. Returns AIMessage with complete tool_calls structure - - Args: - state: LangGraph state dictionary containing messages - tool_name: Name of the tool to invoke (typically "Split_The_Problem") - session_id: Session identifier (format: "call_id_{namespace}") - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - multimodal_processor: Processor for handling image/audio inputs - memory_config: MemoryConfig object containing all configuration - - Returns: - State update with AIMessage containing tool_call - - Examples: - >>> state = {"messages": [HumanMessage(content="What is AI?")]} - >>> result = await create_input_message( - ... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config - ... ) - >>> result["messages"][0].tool_calls[0]["name"] - 'Split_The_Problem' - """ - messages = state.get("messages", []) - - # Extract last message content - if messages: - last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1]) - else: - logger.warning("[create_input_message] No messages in state, using empty string") - last_message = "" - - logger.debug(f"[create_input_message] Original input: {last_message[:100]}...") - - # Process multimodal input (images/audio) - try: - processed_content = await multimodal_processor.process_input(last_message) - if processed_content != last_message: - logger.info( - f"[create_input_message] Multimodal processing converted input " - f"from {len(last_message)} to {len(processed_content)} chars" - ) - last_message = processed_content - except Exception as e: - logger.error( - f"[create_input_message] Multimodal processing failed: {e}", - exc_info=True - ) - # Continue with original content - - # Generate unique message ID - uuid_str = uuid.uuid4() - time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # Extract namespace from session_id - # Expected format: "call_id_{namespace}" or similar - try: - namespace = str(session_id).split('_id_')[1] - except (IndexError, AttributeError): - logger.warning( - f"[create_input_message] Could not extract namespace from session_id: {session_id}" - ) - namespace = "unknown" - - # Handle verified_data extraction (backward compatibility) - # This regex-based extraction is kept for compatibility with existing data formats - if 'verified_data' in str(last_message): - try: - messages_last = str(last_message).replace('\\n', '').replace('\\', '') - query_match = re.findall(r'"query": "(.*?)",', messages_last) - if query_match: - last_message = query_match[0] - logger.debug( - f"[create_input_message] Extracted query from verified_data: {last_message}" - ) - except Exception as e: - logger.warning( - f"[create_input_message] Failed to extract query from verified_data: {e}" - ) - - # Construct tool call message - tool_call_id = f"{session_id}_{uuid_str}" - - logger.info( - f"[create_input_message] Creating tool call for '{tool_name}' " - f"with ID: {tool_call_id}" - ) - - # Build tool arguments - tool_args = { - "sentence": last_message, - "sessionid": session_id, - "messages_id": str(uuid_str), - "search_switch": search_switch, - "apply_id": apply_id, - "group_id": group_id, - "memory_config": memory_config, - } - - return { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": tool_name, - "args": tool_args, - "id": tool_call_id - }] - ) - ] - } diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py new file mode 100644 index 00000000..0c68a47e --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -0,0 +1,237 @@ +import json +import time +from app.core.logging_config import get_agent_logger +from app.db import get_db + +from app.core.memory.agent.models.problem_models import ProblemExtensionResponse +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +db_session = next(get_db()) +logger = get_agent_logger(__name__) + +class ProblemNodeService(LLMServiceMixin): + """问题处理节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +problem_service = ProblemNodeService() + +async def Split_The_Problem(state: ReadState) -> ReadState: + """问题分解节点""" + # 从状态中获取数据 + content = state.get('data', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + + history = await SessionService(store).get_history(group_id, group_id, group_id) + system_prompt = await problem_service.template_service.render_template( + template_name='problem_breakdown_prompt.jinja2', + operation_name='split_the_problem', + history=history, + sentence=content + ) + + try: + # 使用优化的LLM服务 + structured = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) + + # 添加更详细的日志记录 + logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") + + # 验证结构化响应 + if not structured or not hasattr(structured, 'root'): + logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") + split_result = json.dumps([], ensure_ascii=False) + elif not structured.root: + logger.warning("Split_The_Problem: 结构化响应的root为空") + split_result = json.dumps([], ensure_ascii=False) + else: + split_result = json.dumps( + [item.model_dump() for item in structured.root], + ensure_ascii=False + ) + + split_result_dict = [] + for index, item in enumerate(json.loads(split_result)): + split_data = { + "id": f"Q{index+1}", + "question": item['extended_question'], + "type": item['type'], + "reason": item['reason'] + } + split_result_dict.append(split_data) + + logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项") + + result = { + "context": split_result, + "original": content, + "_intermediate": { + "type": "problem_split", + "title": "问题拆分", + "data": split_result_dict, + "original_query": content + } + } + + except Exception as e: + logger.error( + f"Split_The_Problem failed: {e}", + exc_info=True + ) + + # 提供更详细的错误信息 + error_details = { + "error_type": type(e).__name__, + "error_message": str(e), + "content_length": len(content), + "llm_model_id": memory_config.llm_model_id if memory_config else None + } + + logger.error(f"Split_The_Problem error details: {error_details}") + + # 创建默认的空结果 + result = { + "context": json.dumps([], ensure_ascii=False), + "original": content, + "error": str(e), + "_intermediate": { + "type": "problem_split", + "title": "问题拆分", + "data": [], + "original_query": content, + "error": error_details + } + } + + # 返回更新后的状态,包含spit_context字段 + return {"spit_data": result} + +async def Problem_Extension(state: ReadState) -> ReadState: + """问题扩展节点""" + # 获取原始数据和分解结果 + start = time.time() + content = state.get('data', '') + data = state.get('spit_data', '')['context'] + group_id = state.get('group_id', '') + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + memory_config = state.get('memory_config', None) + + databasets = {} + try: + data = json.loads(data) + for i in data: + databasets[i['extended_question']] = i['type'] + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.error(f"Problem_Extension: 数据解析失败: {e}") + # 使用空字典作为fallback + databasets = {} + data = [] + + history = await SessionService(store).get_history(group_id, group_id, group_id) + system_prompt = await problem_service.template_service.render_template( + template_name='Problem_Extension_prompt.jinja2', + operation_name='problem_extension', + history=history, + questions=databasets + ) + + try: + # 使用优化的LLM服务 + response_content = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) + + logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") + + # 验证结构化响应 + if not response_content or not hasattr(response_content, 'root'): + logger.warning("Problem_Extension: 结构化响应为空或格式不正确") + aggregated_dict = {} + elif not response_content.root: + logger.warning("Problem_Extension: 结构化响应的root为空") + aggregated_dict = {} + else: + # Aggregate results by original question + aggregated_dict = {} + for item in response_content.root: + try: + key = getattr(item, "original_question", None) or ( + item.get("original_question") if isinstance(item, dict) else None + ) + value = getattr(item, "extended_question", None) or ( + item.get("extended_question") if isinstance(item, dict) else None + ) + if not key or not value: + logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}") + continue + aggregated_dict.setdefault(key, []).append(value) + except Exception as item_error: + logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}") + continue + + logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组") + + except Exception as e: + logger.error( + f"LLM call failed for Problem_Extension: {e}", + exc_info=True + ) + + # 提供更详细的错误信息 + error_details = { + "error_type": type(e).__name__, + "error_message": str(e), + "questions_count": len(databasets), + "llm_model_id": memory_config.llm_model_id if memory_config else None + } + + logger.error(f"Problem_Extension error details: {error_details}") + aggregated_dict = {} + + logger.info("Problem extension") + logger.info(f"Problem extension result: {aggregated_dict}") + + # Emit intermediate output for frontend + print(time.time() - start) + result = { + "context": aggregated_dict, + "original": data, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "problem_extension", + "title": "问题扩展", + "data": aggregated_dict, + "original_query": content, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + + return {"problem_extension": result} + + + diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py new file mode 100644 index 00000000..14f8fa8b --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -0,0 +1,417 @@ +# ===== 标准库 ===== +import asyncio +import json +import os + +# ===== 第三方库 ===== +from langchain.agents import create_agent +from langchain_openai import ChatOpenAI +from app.core.logging_config import get_agent_logger +from app.db import get_db, get_db_context + +from app.schemas import model_schema +from app.services.memory_config_service import MemoryConfigService +from app.services.model_service import ModelConfigService + +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + COUNTState, + ReadState, + deduplicate_entries, + merge_to_key_value_pairs, +) +from app.core.memory.agent.langgraph_graph.tools.tool import ( + create_hybrid_retrieval_tool_sync, + create_time_retrieval_tool, + extract_tool_message_content, +) + +from app.core.rag.nlp.search import knowledge_retrieval + +logger = get_agent_logger(__name__) +db = next(get_db()) + + + +async def rag_config(state): + user_rag_memory_id = state.get('user_rag_memory_id', '') + kb_config = { + "knowledge_bases": [ + { + "kb_id": user_rag_memory_id, + "similarity_threshold": 0.7, + "vector_similarity_weight": 0.5, + "top_k": 10, + "retrieve_type": "participle" + } + ], + "merge_strategy": "weight", + "reranker_id": os.getenv('reranker_id'), + "reranker_top_k": 10 + } + return kb_config +async def rag_knowledge(state,question): + kb_config = await rag_config(state) + group_id = state.get('group_id', '') + user_rag_memory_id=state.get("user_rag_memory_id",'') + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) + try: + retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] + clean_content = '\n\n'.join(retrieval_knowledge) + cleaned_query = question + raw_results = clean_content + logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") + except Exception : + retrieval_knowledge=[] + clean_content = '' + raw_results = '' + cleaned_query = question + logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") + return retrieval_knowledge,clean_content,cleaned_query,raw_results + + +async def llm_infomation(state: ReadState) -> ReadState: + memory_config = state.get('memory_config', None) + model_id = memory_config.llm_model_id + tenant_id = memory_config.tenant_id + + # 使用现有的 memory_config 而不是重新查询数据库 + # 或者使用线程安全的数据库访问 + with get_db_context() as db: + result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id) + result_pydantic = model_schema.ModelConfig.model_validate(result_orm) + return result_pydantic + + +async def clean_databases(data) -> str: + """ + 简化的数据库搜索结果清理函数 + + Args: + data: 搜索结果数据 + + Returns: + 清理后的内容字符串 + """ + try: + # 解析JSON字符串 + if isinstance(data, str): + try: + data = json.loads(data) + except json.JSONDecodeError: + return data + + if not isinstance(data, dict): + return str(data) + + # 获取结果数据 + # with open("搜索结果.json","w",encoding='utf-8') as f: + # f.write(json.dumps(data, indent=4, ensure_ascii=False)) + results = data.get('results', data) + if not isinstance(results, dict): + return str(results) + + # 收集所有内容 + content_list = [] + + # 处理重排序结果 + reranked = results.get('reranked_results', {}) + if reranked: + for category in ['summaries', 'statements', 'chunks', 'entities']: + items = reranked.get(category, []) + if isinstance(items, list): + content_list.extend(items) + # 处理时间搜索结果 + time_search = results.get('time_search', {}) + if time_search: + if isinstance(time_search, dict): + statements = time_search.get('statements', time_search.get('time_search', [])) + if isinstance(statements, list): + content_list.extend(statements) + elif isinstance(time_search, list): + content_list.extend(time_search) + + # 提取文本内容 + text_parts = [] + for item in content_list: + if isinstance(item, dict): + text = item.get('statement') or item.get('content', '') + if text: + text_parts.append(text) + elif isinstance(item, str): + text_parts.append(item) + + + return '\n'.join(text_parts).strip() + + except Exception as e: + logger.error(f"clean_databases failed: {e}", exc_info=True) + return str(data) + + +async def retrieve_nodes(state: ReadState) -> ReadState: + + ''' + + 模型信息 + ''' + + problem_extension=state.get('problem_extension', '')['context'] + storage_type=state.get('storage_type', '') + user_rag_memory_id=state.get('user_rag_memory_id', '') + group_id=state.get('group_id', '') + memory_config = state.get('memory_config', None) + original=state.get('data', '') + problem_list=[] + for key,values in problem_extension.items(): + for data in values: + problem_list.append(data) + logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + # 创建异步任务处理单个问题 + async def process_question_nodes(idx, question): + try: + # Prepare search parameters based on storage type + search_params = { + "group_id": group_id, + "question": question, + "return_raw_results": True + } + if storage_type == "rag" and user_rag_memory_id: + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + else: + clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search( + **search_params, memory_config=memory_config + ) + + return { + "Query_small": cleaned_query, + "Result_small": clean_content, + "_intermediate": { + "type": "search_result", + "query": cleaned_query, + "raw_results": raw_results, + "index": idx + 1, + "total": len(problem_list) + } + } + + except Exception as e: + logger.error( + f"Retrieve: hybrid_search failed for question '{question}': {e}", + exc_info=True + ) + # Return empty result for this question + return { + "Query_small": question, + "Result_small": "", + "_intermediate": { + "type": "search_result", + "query": question, + "raw_results": [], + "index": idx + 1, + "total": len(problem_list) + } + } + + # 并发处理所有问题 + tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)] + databases_anser = await asyncio.gather(*tasks) + databases_data = { + "Query": original, + "Expansion_issue": databases_anser + } + + # Collect intermediate outputs before deduplication + intermediate_outputs = [] + for item in databases_anser: + if '_intermediate' in item: + intermediate_outputs.append(item['_intermediate']) + + # Deduplicate and merge results + deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) + deduplicated_data_merged = merge_to_key_value_pairs( + deduplicated_data, + 'Query_small', + 'Result_small' + ) + + # Restructure for Verify/Retrieve_Summary compatibility + keys, val = [], [] + for item in deduplicated_data_merged: + for items_key, items_value in item.items(): + keys.append(items_key) + val.append(items_value) + + send_verify = [] + for i, j in zip(keys, val, strict=False): + if j!=['']: + send_verify.append({ + "Query_small": i, + "Answer_Small": j + }) + + dup_databases = { + "Query": original, + "Expansion_issue": send_verify, + "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs + } + + logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") + return {'retrieve':dup_databases} + + + + +async def retrieve(state: ReadState) -> ReadState: + # 从state中获取group_id + import time + start=time.time() + problem_extension = state.get('problem_extension', '')['context'] + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + original = state.get('data', '') + problem_list = [] + for key, values in problem_extension.items(): + for data in values: + problem_list.append(data) + logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + databases_anser = [] + + async def get_llm_info(): + with get_db_context() as db: # 使用同步数据库上下文管理器 + config_service = MemoryConfigService(db) + return await llm_infomation(state) + llm_config = await get_llm_info() + api_key_obj = llm_config.api_keys[0] + api_key = api_key_obj.api_key + api_base = api_key_obj.api_base + model_name = api_key_obj.model_name + llm = ChatOpenAI( + model=model_name, + api_key=api_key, + base_url=api_base, + temperature=0.2, + ) + + time_retrieval_tool = create_time_retrieval_tool(group_id) + search_params = { "group_id": group_id, "return_raw_results": True } + hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) + agent = create_agent( + llm, + tools=[time_retrieval_tool,hybrid_retrieval], + system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" + ) + + # 创建异步任务处理单个问题 + import asyncio + + # 在模块级别定义信号量,限制最大并发数 + SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作 + + async def process_question(idx, question): + async with SEMAPHORE: # 限制并发 + try: + if storage_type == "rag" and user_rag_memory_id: + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + else: + cleaned_query = question + # 使用 asyncio 在线程池中运行同步的 agent.invoke + import asyncio + response = await asyncio.get_event_loop().run_in_executor( + None, + lambda: agent.invoke({"messages": question}) + ) + tool_results = extract_tool_message_content(response) + if tool_results == None: + raw_results = [] + clean_content = '' + else: + raw_results = tool_results['content'] + clean_content = await clean_databases(raw_results) + + try: + raw_results = raw_results['results'] + except Exception: + raw_results = [] + + return { + "Query_small": cleaned_query, + "Result_small": clean_content, + "_intermediate": { + "type": "search_result", + "query": cleaned_query, + "raw_results": raw_results, + "index": idx + 1, + "total": len(problem_list) + } + } + + except Exception as e: + logger.error( + f"Retrieve: hybrid_search failed for question '{question}': {e}", + exc_info=True + ) + # Return empty result for this question + return { + "Query_small": question, + "Result_small": "", + "_intermediate": { + "type": "search_result", + "query": question, + "raw_results": [], + "index": idx + 1, + "total": len(problem_list) + } + } + + # 并发处理所有问题 + import asyncio + tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)] + databases_anser = await asyncio.gather(*tasks) + databases_data = { + "Query": original, + "Expansion_issue": databases_anser + } + + # Collect intermediate outputs before deduplication + intermediate_outputs = [] + for item in databases_anser: + if '_intermediate' in item: + intermediate_outputs.append(item['_intermediate']) + + # Deduplicate and merge results + deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) + deduplicated_data_merged = merge_to_key_value_pairs( + deduplicated_data, + 'Query_small', + 'Result_small' + ) + + # Restructure for Verify/Retrieve_Summary compatibility + keys, val = [], [] + for item in deduplicated_data_merged: + for items_key, items_value in item.items(): + keys.append(items_key) + val.append(items_value) + + send_verify = [] + for i, j in zip(keys, val, strict=False): + if j != ['']: + send_verify.append({ + "Query_small": i, + "Answer_Small": j + }) + + dup_databases = { + "Query": original, + "Expansion_issue": send_verify, + "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs + } + # with open('retrieve_text.json', 'w') as f: + # json.dump(dup_databases, f, indent=4) + logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") + return {'retrieve': dup_databases} + + diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py new file mode 100644 index 00000000..7b727da5 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -0,0 +1,303 @@ + + +import time + +from app.core.logging_config import get_agent_logger, log_time +from app.db import get_db + +from app.core.memory.agent.models.summary_models import ( + RetrieveSummaryResponse, + SummaryResponse, +) +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +logger = get_agent_logger(__name__) +db_session = next(get_db()) + +class SummaryNodeService(LLMServiceMixin): + """总结节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +summary_service = SummaryNodeService() + +async def summary_history(state: ReadState) -> ReadState: + group_id = state.get("group_id", '') + history = await SessionService(store).get_history(group_id, group_id, group_id) + return history + +async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: + """ + 增强的summary_llm函数,包含更好的错误处理和数据验证 + """ + data = state.get("data", '') + + # 构建系统提示词 + if str(search_mode) == "0": + system_prompt = await summary_service.template_service.render_template( + template_name=template_name, + operation_name=operation_name, + data=retrieve_info, + query=data + ) + else: + system_prompt = await summary_service.template_service.render_template( + template_name=template_name, + operation_name=operation_name, + query=data, + history=history, + retrieve_info=retrieve_info + ) + try: + # 使用优化的LLM服务进行结构化输出 + structured = await summary_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=response_model, + fallback_value=None + ) + # 验证结构化响应 + if structured is None: + logger.warning(f"LLM返回None,使用默认回答") + return "信息不足,无法回答" + + # 根据操作类型提取答案 + if operation_name == "summary": + aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" + else: + # 处理RetrieveSummaryResponse + if hasattr(structured, 'data') and structured.data: + aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" + else: + logger.warning(f"结构化响应缺少data字段") + aimessages = "信息不足,无法回答" + + # 验证答案不为空 + if not aimessages or aimessages.strip() == "": + aimessages = "信息不足,无法回答" + + return aimessages + + except Exception as e: + logger.error(f"结构化输出失败: {e}", exc_info=True) + + # 尝试非结构化输出作为fallback + try: + logger.info("尝试非结构化输出作为fallback") + response = await summary_service.call_llm_simple( + state=state, + db_session=db_session, + system_prompt=system_prompt, + fallback_message="信息不足,无法回答" + ) + + if response and response.strip(): + # 简单清理响应 + cleaned_response = response.strip() + # 移除可能的JSON标记 + if cleaned_response.startswith('```'): + lines = cleaned_response.split('\n') + cleaned_response = '\n'.join(lines[1:-1]) + + return cleaned_response + else: + return "信息不足,无法回答" + + except Exception as fallback_error: + logger.error(f"Fallback也失败: {fallback_error}") + return "信息不足,无法回答" + +async def summary_redis_save(state: ReadState,aimessages) -> ReadState: + data = state.get("data", '') + group_id = state.get("group_id", '') + await SessionService(store).save_session( + user_id=group_id, + query=data, + apply_id=group_id, + group_id=group_id, + ai_response=aimessages + ) + await SessionService(store).cleanup_duplicates() + logger.info(f"sessionid: {aimessages} 写入成功") +async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: + storage_type=state.get("storage_type",'') + user_rag_memory_id=state.get("user_rag_memory_id",'') + data=state.get("data", '') + input_summary = { + "status": "success", + "summary_result": aimessages, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "input_summary", + "title": "快速答案", + "summary": aimessages, + "query": data, + "raw_results": raw_results, + "search_mode": "quick_search", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + retrieve={ + "status": "success", + "summary_result": aimessages, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "retrieval_summary", + "title":"快速检索", + "summary": aimessages, + "query": data, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + + return input_summary,retrieve + +async def Input_Summary(state: ReadState) -> ReadState: + start=time.time() + storage_type=state.get("storage_type",'') + memory_config = state.get('memory_config', None) + user_rag_memory_id=state.get("user_rag_memory_id",'') + data=state.get("data", '') + group_id=state.get("group_id", '') + logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + history = await summary_history( state) + search_params = { + "group_id": group_id, + "question": data, + "return_raw_results": True + } + + try: + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + except Exception as e: + logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) + retrieve_info, question, raw_results = "", data, [] + + + try: + # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', + # 'input_summary',RetrieveSummaryResponse) + # logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}") + summary_result = await summary_prompt(state, retrieve_info, retrieve_info) + summary = summary_result[0] + except Exception as e: + logger.error( f"Input_Summary failed: {e}", exc_info=True ) + summary= { + "status": "fail", + "summary_result": "信息不足,无法回答", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "error": str(e) + } + end = time.time() + try: + duration = end - start + except Exception: + duration = 0.0 + log_time('检索', duration) + return {"summary":summary} + +async def Retrieve_Summary(state: ReadState)-> ReadState: + retrieve=state.get("retrieve", '') + history = await summary_history( state) + import json + with open("检索.json","w",encoding='utf-8') as f: + f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) + retrieve=retrieve.get("Expansion_issue", []) + start=time.time() + retrieve_info_str=[] + for data in retrieve: + if data=='': + retrieve_info_str='' + else: + for key, value in data.items(): + if key=='Answer_Small': + for i in value: + retrieve_info_str.append(i) + retrieve_info_str=list(set(retrieve_info_str)) + retrieve_info_str='\n'.join(retrieve_info_str) + + aimessages=await summary_llm(state,history,retrieve_info_str, + 'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") + if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": + await summary_redis_save(state, aimessages) + if aimessages == '': + aimessages = '信息不足,无法回答' + logger.info(f"Summary after retrieval: {aimessages}") + end = time.time() + try: + duration = end - start + except Exception: + duration = 0.0 + log_time('Retrieval summary', duration) + + # 修复协程调用 - 先await,然后访问返回值 + summary_result = await summary_prompt(state, aimessages, retrieve_info_str) + summary = summary_result[1] + return {"summary":summary} + + +async def Summary(state: ReadState)-> ReadState: + start=time.time() + query = state.get("data", '') + verify=state.get("verify", '') + verify_expansion_issue=verify.get("verified_data", '') + retrieve_info_str='' + for data in verify_expansion_issue: + for key, value in data.items(): + if key=='answer_small': + for i in value: + retrieve_info_str+=i+'\n' + history=await summary_history(state) + + data = { + "query": query, + "history": history, + "retrieve_info": retrieve_info_str + } + aimessages=await summary_llm(state,history,data, + 'summary_prompt.jinja2','summary',SummaryResponse,0) + + + if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": + await summary_redis_save(state, aimessages) + if aimessages == '': + aimessages = '信息不足,无法回答' + try: + duration = time.time() - start + except Exception: + duration = 0.0 + log_time('Retrieval summary', duration) + + # 修复协程调用 - 先await,然后访问返回值 + summary_result = await summary_prompt(state, aimessages, retrieve_info_str) + summary = summary_result[1] + return {"summary":summary} + +async def Summary_fails(state: ReadState)-> ReadState: + storage_type=state.get("storage_type", '') + user_rag_memory_id=state.get("user_rag_memory_id", '') + result= { + "status": "success", + "summary_result": "没有相关数据", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + return {"summary":result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py deleted file mode 100644 index 4727fb9c..00000000 --- a/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -Tool execution node for LangGraph workflow. - -This module provides the ToolExecutionNode class which wraps tool execution -with parameter transformation logic using the ParameterBuilder service. -""" - -import logging -import time -from typing import Any, Callable, Dict - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_content_payload, - extract_tool_call_id, -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage -from langgraph.prebuilt import ToolNode - -logger = logging.getLogger(__name__) - - -class ToolExecutionNode: - """ - Custom LangGraph node that wraps tool execution with parameter transformation. - - This node extracts content from previous tool results, transforms parameters - based on tool type using ParameterBuilder, and invokes the tool with the - correct argument structure. - - Attributes: - tool_node: LangGraph ToolNode wrapping the actual tool - id: Node identifier for message IDs - tool_name: Name of the tool being executed - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - memory_config: MemoryConfig object containing all configuration - """ - - def __init__( - self, - tool: Callable, - node_id: str, - namespace: str, - search_switch: str, - apply_id: str, - group_id: str, - parameter_builder: ParameterBuilder, - storage_type: str, - user_rag_memory_id: str, - memory_config: MemoryConfig, - ): - """ - Initialize the tool execution node. - - Args: - tool: The tool function to execute - node_id: Identifier for this node (used in message IDs) - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - storage_type: Storage type for the workspace - user_rag_memory_id: User RAG memory identifier - memory_config: MemoryConfig object containing all configuration - """ - self.tool_node = ToolNode([tool]) - self.id = node_id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.parameter_builder = parameter_builder - self.storage_type = storage_type - self.user_rag_memory_id = user_rag_memory_id - self.memory_config = memory_config - - logger.info( - f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'" - ) - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """ - Execute the tool with transformed parameters. - - This method: - 1. Extracts the last message from state - 2. Extracts tool call ID using state extractors - 3. Extracts content payload using state extractors - 4. Builds tool arguments using parameter builder - 5. Constructs AIMessage with tool_calls - 6. Invokes the tool and returns the result - - Args: - state: LangGraph state dictionary - - Returns: - Updated state with tool result in messages - """ - messages = state.get("messages", []) - logger.debug( self.tool_name) - - if not messages: - logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state") - return {"messages": [AIMessage(content="Error: No messages in state")]} - - last_message = messages[-1] - logger.debug( - f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}" - ) - - try: - # Extract tool call ID using state extractors - tool_call_id = extract_tool_call_id(last_message) - logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}") - - except ValueError as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}" - ) - return {"messages": [AIMessage(content=f"Error: {str(e)}")]} - - try: - # Extract content payload using state extractors - content = extract_content_payload(last_message) - logger.debug( - f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}" - ) - # Log raw message content for debugging - if hasattr(last_message, 'content'): - raw = last_message.content - logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}") - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}", - exc_info=True - ) - content = {} - - try: - # Build tool arguments using parameter builder - tool_args = self.parameter_builder.build_tool_args( - tool_name=self.tool_name, - content=content, - tool_call_id=tool_call_id, - search_switch=self.search_switch, - apply_id=self.apply_id, - group_id=self.group_id, - memory_config=self.memory_config, - storage_type=self.storage_type, - user_rag_memory_id=self.user_rag_memory_id, - ) - logger.debug( - f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}" - ) - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}", - exc_info=True - ) - return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]} - - # Construct tool input message - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": f"{self.id}_{tool_call_id}", - }] - ) - ] - } - - try: - # Invoke the tool - result = await self.tool_node.ainvoke(tool_input) - - logger.debug( - f"[ToolExecutionNode] {self.id} - Tool execution completed" - ) - - # Check for error in tool response - error_entry = None - if result and "messages" in result: - for msg in result["messages"]: - if hasattr(msg, 'content'): - try: - import json - content = msg.content - if isinstance(content, str): - parsed = json.loads(content) - if isinstance(parsed, dict) and "error" in parsed: - error_msg = parsed["error"] - logger.warning( - f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}" - ) - error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id} - except (json.JSONDecodeError, TypeError): - pass - - # Return result with error tracking if error was found - if error_entry: - result["errors"] = [error_entry] - - return result - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}", - exc_info=True - ) - # Track error in state and return error message - from langchain_core.messages import ToolMessage - error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id} - return { - "messages": [ - ToolMessage( - content=f"Error executing tool: {str(e)}", - tool_call_id=f"{self.id}_{tool_call_id}" - ) - ], - "errors": [error_entry] - } diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py new file mode 100644 index 00000000..f3a39afb --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -0,0 +1,85 @@ + +from app.core.logging_config import get_agent_logger +from app.db import get_db + +from app.core.memory.agent.models.verification_models import VerificationResult +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +db_session = next(get_db()) +logger = get_agent_logger(__name__) + +class VerificationNodeService(LLMServiceMixin): + """验证节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +verification_service = VerificationNodeService() + +async def Verify_prompt(state: ReadState,messages_deal): + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + data = state.get('data', '') + Verify_result = { + "status": messages_deal.split_result, + "verified_data": messages_deal.expansion_issue, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "verification", + "title": "Data Verification", + "result": messages_deal.split_result, + "reason": messages_deal.reason, + "query": data, + "verified_count": len(messages_deal.expansion_issue), + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + return Verify_result +async def Verify(state: ReadState): + content = state.get('data', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + + history = await SessionService(store).get_history(group_id, group_id, group_id) + + retrieve = state.get("retrieve", '') + retrieve = retrieve.get("Expansion_issue", []) + messages = { + "Query": content, + "Expansion_issue": retrieve + } + + system_prompt = await verification_service.template_service.render_template( + template_name='split_verify_prompt.jinja2', + operation_name='split_verify_prompt', + history=history, + sentence=messages + ) + + # 使用优化的LLM服务 + structured = await verification_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=VerificationResult, + fallback_value={ + "split_result": "fail", + "expansion_issue": [], + "reason": "验证失败" + } + ) + + result = await Verify_prompt(state, structured) + return {"verify": result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py new file mode 100644 index 00000000..8421d059 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -0,0 +1,50 @@ + +from app.core.memory.agent.utils.llm_tools import WriteState +from app.core.memory.agent.utils.write_tools import write +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) +async def write_node(state: WriteState) -> WriteState: + """ + Write data to the database/file system. + + Args: + ctx: FastMCP context for dependency injection + content: Data content to write + user_id: User identifier + apply_id: Application identifier + group_id: Group identifier + memory_config: MemoryConfig object containing all configuration + + Returns: + dict: Contains 'status', 'saved_to', and 'data' fields + """ + content=state.get('data','') + group_id=state.get('group_id','') + memory_config=state.get('memory_config', '') + try: + result=await write( + content=content, + user_id=group_id, + apply_id=group_id, + group_id=group_id, + memory_config=memory_config, + ) + logger.info(f"Write completed successfully! Config: {memory_config.config_name}") + + write_result= { + "status": "success", + "data": content, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name, + } + return {"write_result":write_result} + + + except Exception as e: + logger.error(f"Data_write failed: {e}", exc_info=True) + write_result= { + "status": "error", + "message": str(e), + } + return {"write_result": write_result} diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index c29b5d86..19011a5f 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -1,469 +1,177 @@ -import json -import os -import re -import time -import warnings +#!/usr/bin/env python3 from contextlib import asynccontextmanager -from typing import Literal -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.nodes import ( - ToolExecutionNode, - create_input_message, -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder -from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState -from app.core.memory.agent.utils.multimodal import MultimodalProcessor -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from langchain_core.messages import AIMessage -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.constants import END, START +from langchain_core.messages import HumanMessage +from langgraph.constants import START, END from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode - -logger = get_agent_logger(__name__) - -warnings.filterwarnings("ignore", category=RuntimeWarning) -load_dotenv() -redishost=os.getenv("REDISHOST") -redisport=os.getenv('REDISPORT') -redisdb=os.getenv('REDISDB') -redispassword=os.getenv('REDISPASSWORD') -counter = COUNTState(limit=3) - -# Update loop count in workflow -async def update_loop_count(state): - """Update loop counter""" - current_count = state.get("loop_count", 0) - return {"loop_count": current_count + 1} -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: - messages = state["messages"] +from app.db import get_db +from app.services.memory_config_service import MemoryConfigService - # Add boundary check - if not messages: - return END - counter.add(1) # Increment by 1 +from app.core.memory.agent.utils.llm_tools import ReadState +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node +from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( + Split_The_Problem, + Problem_Extension, +) +from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( + retrieve, +) +from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( + Input_Summary, + Retrieve_Summary, + Summary_fails, + Summary, +) +from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify +from app.core.memory.agent.langgraph_graph.routing.routers import ( + Split_continue, + Retrieve_continue, + Verify_continue, +) - loop_count = counter.get_total() - logger.debug(f"[should_continue] Current loop count: {loop_count}") - - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"Status tools: {status_tools}") - - if "success" in status_tools: - counter.reset() - return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # Maximum loop count is 3 - return "content_input" - else: - counter.reset() - return "Summary_fails" - else: - # Add default return value to avoid returning None - counter.reset() - return "Summary" # Default based on business requirements - - -def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # Add default return value to avoid returning None - return 'Retrieve_Summary' # Default based on business logic - - -def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - logger.debug(f"Split_continue state: {state}") - - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '2': - return 'Input_Summary' - return 'Split_The_Problem' # Default case - - -class ProblemExtensionNode: - def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""): - self.tool_node = ToolNode([tool]) - self.id = id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.storage_type = storage_type - self.user_rag_memory_id = user_rag_memory_id - - async def __call__(self, state): - messages = state["messages"] - last_message = messages[-1] if messages else "" - logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}") - if self.tool_name == 'Input_Summary': - tool_call = re.findall("'id': '(.*?)'", str(last_message))[0] - else: - tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1] - - # Try to extract actual content payload from previous tool result - raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message) - extracted_payload = None - # Capture ToolMessage content field (supports single/double quotes), avoid greedy matching - m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S) - if m: - extracted_payload = m.group(1) - else: - # Fallback: use raw string directly - extracted_payload = raw_msg - - # Try to parse content as JSON first - try: - content = json.loads(extracted_payload) - except Exception: - # Try to extract JSON fragment from text and parse - parsed = None - candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S) - for cand in candidates: - try: - parsed = json.loads(cand) - break - except Exception: - continue - # If still fails, use raw string as content - content = parsed if parsed is not None else extracted_payload - - # Build correct parameters based on tool name - tool_args = {} - - if self.tool_name == "Verify": - # Verify tool requires context and usermessages parameters - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Retrieve": - # Retrieve tool requires context and usermessages parameters - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary": - # Summary tool requires string type context parameter - if isinstance(content, dict): - # Convert dict to JSON string - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary_fails": - # Summary_fails tool requires string type context parameter - if isinstance(content, dict): - # Convert dict to JSON string - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == 'Input_Summary': - tool_args["context"] = str(last_message) - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - tool_args["storage_type"] = getattr(self, 'storage_type', "") - tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "") - elif self.tool_name == 'Retrieve_Summary': - # Retrieve_Summary expects dict directly, not JSON string - # content might be a JSON string, try to parse it - if isinstance(content, str): - try: - parsed_content = json.loads(content) - # Check if it has a "context" key - if isinstance(parsed_content, dict) and "context" in parsed_content: - tool_args["context"] = parsed_content["context"] - else: - tool_args["context"] = parsed_content - except json.JSONDecodeError: - # If parsing fails, wrap the string - tool_args["context"] = {"content": content} - elif isinstance(content, dict): - # Check if content has a "context" key that needs unwrapping - if "context" in content: - tool_args["context"] = content["context"] - else: - tool_args["context"] = content - else: - tool_args["context"] = {"content": str(content)} - - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - else: - # Other tools use context parameter - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - - - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": self.id + f"{tool_call}", - }] - ) - ] - } - result = await self.tool_node.ainvoke(tool_input) - result_text = str(result) - - return {"messages": [AIMessage(content=result_text)]} @asynccontextmanager -async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None): - """ - Create a read graph workflow for memory operations. - - Args: - namespace: Namespace identifier - tools: MCP tools loaded from session - search_switch: Search mode switch ("0", "1", or "2") - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type (optional) - user_rag_memory_id: User RAG memory ID (optional) - """ - memory = InMemorySaver() - tool = [i.name for i in tools] - logger.info(f"Initializing read graph with tools: {tool}") - logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})") - - # Extract tool functions - Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None) - Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None) - Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None) - Verify_ = next((t for t in tools if t.name == "Verify"), None) - Summary_ = next((t for t in tools if t.name == "Summary"), None) - Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None) - Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None) - Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None) - - # Instantiate services - parameter_builder = ParameterBuilder() - multimodal_processor = MultimodalProcessor() - - # Create nodes using new modular components - Split_The_Problem_node = ToolNode([Split_The_Problem_]) - - Problem_Extension_node = ToolExecutionNode( - tool=Problem_Extension_, - node_id="Problem_Extension_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, +async def make_read_graph(): + """创建并返回 LangGraph 工作流""" + try: + # Build workflow graph + workflow = StateGraph(ReadState) + workflow.add_node("content_input", content_input_node) + workflow.add_node("Split_The_Problem", Split_The_Problem) + workflow.add_node("Problem_Extension", Problem_Extension) + workflow.add_node("Input_Summary", Input_Summary) + # workflow.add_node("Retrieve", retrieve_nodes) + workflow.add_node("Retrieve", retrieve) + workflow.add_node("Verify", Verify) + workflow.add_node("Retrieve_Summary", Retrieve_Summary) + workflow.add_node("Summary", Summary) + workflow.add_node("Summary_fails", Summary_fails) + + # 添加边 + workflow.add_edge(START, "content_input") + workflow.add_conditional_edges("content_input", Split_continue) + workflow.add_edge("Input_Summary", END) + workflow.add_edge("Split_The_Problem", "Problem_Extension") + workflow.add_edge("Problem_Extension", "Retrieve") + workflow.add_conditional_edges("Retrieve", Retrieve_continue) + workflow.add_edge("Retrieve_Summary", END) + workflow.add_conditional_edges("Verify", Verify_continue) + workflow.add_edge("Summary_fails", END) + workflow.add_edge("Summary", END) + + + '''-----''' + # workflow.add_edge("Retrieve", END) + + # 编译工作流 + graph = workflow.compile() + yield graph + + except Exception as e: + print(f"创建工作流失败: {e}") + raise + finally: + print("工作流创建完成") + +async def main(): + """主函数 - 运行工作流""" + message = "昨天有什么好看的电影" + group_id = '88a459f5_text09' # 组ID + storage_type = 'neo4j' # 存储类型 + search_switch = '1' # 搜索开关 + user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID + + # 获取数据库会话 + db_session = next(get_db()) + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=17, # 改为整数 + service_name="MemoryAgentService" ) + import time + start=time.time() + try: + async with make_read_graph() as graph: + config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id + ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} + # 获取节点更新信息 + _intermediate_outputs = [] + summary = '' + + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + print(f"处理节点: {node_name}") + + # 处理不同Summary节点的返回结构 + if 'Summary' in node_name: + if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: + summary = node_data['InputSummary']['summary_result'] + elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: + summary = node_data['RetrieveSummary']['summary_result'] + elif 'summary' in node_data and 'summary_result' in node_data['summary']: + summary = node_data['summary']['summary_result'] + elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: + summary = node_data['SummaryFails']['summary_result'] - Retrieve_node = ToolExecutionNode( - tool=Retrieve_, - node_id="Retrieve_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + spit_data = node_data.get('spit_data', {}).get('_intermediate', None) + if spit_data and spit_data != [] and spit_data != {}: + _intermediate_outputs.append(spit_data) + + # Problem_Extension 节点 + problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) + if problem_extension and problem_extension != [] and problem_extension != {}: + _intermediate_outputs.append(problem_extension) + + # Retrieve 节点 + retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) + if retrieve_node and retrieve_node != [] and retrieve_node != {}: + _intermediate_outputs.extend(retrieve_node) + + # Verify 节点 + verify_n = node_data.get('verify', {}).get('_intermediate', None) + if verify_n and verify_n != [] and verify_n != {}: + _intermediate_outputs.append(verify_n) - Verify_node = ToolExecutionNode( - tool=Verify_, - node_id="Verify_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) - - Summary_node = ToolExecutionNode( - tool=Summary_, - node_id="Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + + # Summary 节点 + summary_n = node_data.get('summary', {}).get('_intermediate', None) + if summary_n and summary_n != [] and summary_n != {}: + _intermediate_outputs.append(summary_n) - Summary_fails_node = ToolExecutionNode( - tool=Summary_fails_, - node_id="Summary_fails_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + # # 过滤掉空值 + # _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] + # + # # 优化搜索结果 + # print("=== 开始优化搜索结果 ===") + # optimized_outputs = merge_multiple_search_results(_intermediate_outputs) + # result=reorder_output_results(optimized_outputs) + # # 保存优化后的结果到文件 + # with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f: + # import json + # f.write(json.dumps(result, indent=4, ensure_ascii=False)) + # + print(f"=== 最终摘要 ===") + print(summary) + + except Exception as e: + import traceback + traceback.print_exc() - Retrieve_Summary_node = ToolExecutionNode( - tool=Retrieve_Summary_, - node_id="Retrieve_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + end=time.time() + print(100*'y') + print(f"总耗时: {end-start}s") + print(100*'y') - Input_Summary_node = ToolExecutionNode( - tool=Input_Summary_, - node_id="Input_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) - async def content_input_node(state): - state_search_switch = state.get("search_switch", search_switch) - - tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem" - session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id" - - return await create_input_message( - state=state, - tool_name=tool_name, - session_id=f"{session_prefix}_{namespace}", - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - multimodal_processor=multimodal_processor, - memory_config=memory_config, - ) - - - # Build workflow graph - workflow = StateGraph(ReadState) - workflow.add_node("content_input", content_input_node) - workflow.add_node("Split_The_Problem", Split_The_Problem_node) - workflow.add_node("Problem_Extension", Problem_Extension_node) - workflow.add_node("Retrieve", Retrieve_node) - workflow.add_node("Verify", Verify_node) - workflow.add_node("Summary", Summary_node) - workflow.add_node("Summary_fails", Summary_fails_node) - workflow.add_node("Retrieve_Summary", Retrieve_Summary_node) - workflow.add_node("Input_Summary", Input_Summary_node) - - # Add edges using imported routers - workflow.add_edge(START, "content_input") - workflow.add_conditional_edges("content_input", Split_continue) - workflow.add_edge("Input_Summary", END) - workflow.add_edge("Split_The_Problem", "Problem_Extension") - workflow.add_edge("Problem_Extension", "Retrieve") - workflow.add_conditional_edges("Retrieve", Retrieve_continue) - workflow.add_edge("Retrieve_Summary", END) - workflow.add_conditional_edges("Verify", Verify_continue) - workflow.add_edge("Summary_fails", END) - workflow.add_edge("Summary", END) - - graph = workflow.compile(checkpointer=memory) - yield graph +if __name__ == "__main__": + import asyncio + asyncio.run(main()) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/__init__.py b/api/app/core/memory/agent/langgraph_graph/routing/__init__.py deleted file mode 100644 index a9366bd0..00000000 --- a/api/app/core/memory/agent/langgraph_graph/routing/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph routing logic.""" - -from app.core.memory.agent.langgraph_graph.routing.routers import ( - Verify_continue, - Retrieve_continue, - Split_continue, -) - -__all__ = [ - "Verify_continue", - "Retrieve_continue", - "Split_continue", -] diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py index c8abd544..c0b01be1 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py @@ -1,123 +1,62 @@ -""" -Routing functions for LangGraph conditional edges. -This module provides routing functions that determine the next node to execute -based on state values. All functions return Literal types for type safety. -""" - -import logging -import re from typing import Literal -from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch +from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState -logger = logging.getLogger(__name__) -# Global counter for Verify routing +logger = get_agent_logger(__name__) counter = COUNTState(limit=3) - - -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: +def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]: """ - Determine routing after Verify node based on verification result. - - This function checks the verification result in the last message and routes to: - - Summary: if verification succeeded - - content_input: if verification failed and retry limit not reached - - Summary_fails: if verification failed and retry limit reached - + Determine routing based on search_switch value. + Args: - state: LangGraph state containing messages - + state: State dictionary containing search_switch + Returns: - Next node name as Literal type + Next node to execute """ - messages = state.get("messages", []) - - # Boundary check - if not messages: - logger.warning("[Verify_continue] No messages in state, defaulting to Summary") - counter.reset() - return "Summary" - - # Increment counter - counter.add(1) + logger.debug(f"Split_continue state: {state}") + search_switch = state.get('search_switch', '') + if search_switch is not None: + search_switch = str(search_switch) + if search_switch == '2': + return 'Input_Summary' + return 'Split_The_Problem' # 默认情况 + +def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: + """ + Determine routing based on search_switch value. + + Args: + state: State dictionary containing search_switch + + Returns: + Next node to execute + """ + search_switch = state.get('search_switch', '') + if search_switch is not None: + search_switch = str(search_switch) + if search_switch == '0': + return 'Verify' + elif search_switch == '1': + return 'Retrieve_Summary' + return 'Retrieve_Summary' # Default based on business logic +def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: + status=state.get('verify', '')['status'] loop_count = counter.get_total() - logger.debug(f"[Verify_continue] Current loop count: {loop_count}") - - # Extract verification result from last message - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"[Verify_continue] Status tools: {status_tools}") - - # Route based on verification result - if "success" in status_tools: + print(status) + if "success" in status: counter.reset() return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # Max retry count is 2 + elif "failed" in status: + if loop_count < 2: # Maximum loop count is 3 return "content_input" else: counter.reset() return "Summary_fails" - else: - # Default to Summary if status is unclear - counter.reset() - return "Summary" - - -def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing after Retrieve node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '0': Route to Verify (verification needed) - - search_switch == '1': Route to Retrieve_Summary (direct summary) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - search_switch = extract_search_switch(state) - - logger.debug(f"[Retrieve_continue] search_switch: {search_switch}") - - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # Default to Retrieve_Summary - logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary") - return 'Retrieve_Summary' - - -def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing after content_input node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '2': Route to Input_Summary (direct input summary) - - Otherwise: Route to Split_The_Problem (problem decomposition) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - logger.debug(f"[Split_continue] state keys: {state.keys()}") - - search_switch = extract_search_switch(state) - - logger.debug(f"[Split_continue] search_switch: {search_switch}") - - if search_switch == '2': - return 'Input_Summary' - - # Default to Split_The_Problem - return 'Split_The_Problem' + # else: + # # Add default return value to avoid returning None + # counter.reset() + # return "Summary" # Default based on business requirements diff --git a/api/app/core/memory/agent/langgraph_graph/state/__init__.py b/api/app/core/memory/agent/langgraph_graph/state/__init__.py deleted file mode 100644 index 279c6463..00000000 --- a/api/app/core/memory/agent/langgraph_graph/state/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph state management utilities.""" - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_search_switch, - extract_tool_call_id, - extract_content_payload, -) - -__all__ = [ - "extract_search_switch", - "extract_tool_call_id", - "extract_content_payload", -] diff --git a/api/app/core/memory/agent/langgraph_graph/state/extractors.py b/api/app/core/memory/agent/langgraph_graph/state/extractors.py deleted file mode 100644 index f5a32f5d..00000000 --- a/api/app/core/memory/agent/langgraph_graph/state/extractors.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -State extraction utilities for type-safe access to LangGraph state values. - -This module provides utility functions for extracting values from LangGraph state -dictionaries with proper error handling and sensible defaults. -""" - -import json -import logging -from typing import Any, Optional - -logger = logging.getLogger(__name__) - -def extract_search_switch(state: dict) -> Optional[str]: - """ - Extract search_switch from state or messages. - """ - - search_switch = state.get("search_switch") - - if search_switch is not None: - return str(search_switch) - - # Try to extract from messages - messages = state.get("messages", []) - if not messages: - return None - - # 从最新的消息开始查找 - for message in reversed(messages): - # 尝试从 tool_calls 中提取 - if hasattr(message, "tool_calls") and message.tool_calls: - for tool_call in message.tool_calls: - if isinstance(tool_call, dict): - # 从 tool_call 的 args 中提取 - if "args" in tool_call and isinstance(tool_call["args"], dict): - search_switch = tool_call["args"].get("search_switch") - if search_switch is not None: - return str(search_switch) - # 直接从 tool_call 中提取 - search_switch = tool_call.get("search_switch") - if search_switch is not None: - return str(search_switch) - - # 尝试从 content 中提取(如果是 JSON 格式) - if hasattr(message, "content"): - try: - import json - if isinstance(message.content, str): - content_data = json.loads(message.content) - if isinstance(content_data, dict): - search_switch = content_data.get("search_switch") - if search_switch is not None: - return str(search_switch) - except (json.JSONDecodeError, ValueError): - pass - - return None - - -def extract_tool_call_id(message: Any) -> str: - """ - Extract tool call ID from message using structured attributes. - - This function extracts the tool call ID from a message object, handling both - direct attribute access and tool_calls list structures. - - Args: - message: Message object (typically ToolMessage or AIMessage) - - Returns: - Tool call ID as string - - Raises: - ValueError: If tool call ID cannot be extracted - - Examples: - >>> message = ToolMessage(content="...", tool_call_id="call_123") - >>> extract_tool_call_id(message) - 'call_123' - """ - # Try direct attribute access for ToolMessage - if hasattr(message, "tool_call_id"): - tool_call_id = message.tool_call_id - if tool_call_id: - return str(tool_call_id) - - # Try extracting from tool_calls list for AIMessage - if hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "id" in tool_call: - return str(tool_call["id"]) - - # Try extracting from id attribute - if hasattr(message, "id"): - message_id = message.id - if message_id: - return str(message_id) - - # If all else fails, raise an error - raise ValueError(f"Could not extract tool call ID from message: {type(message)}") - - -def extract_content_payload(message: Any) -> Any: - """ - Extract content payload from ToolMessage, parsing JSON if needed. - - This function extracts the content from a message and attempts to parse it as JSON - if it appears to be a JSON string. It handles various message formats and provides - sensible fallbacks. - - Args: - message: Message object (typically ToolMessage) - - Returns: - Parsed content (dict, list, or str) - - Examples: - >>> message = ToolMessage(content='{"key": "value"}') - >>> extract_content_payload(message) - {'key': 'value'} - - >>> message = ToolMessage(content='plain text') - >>> extract_content_payload(message) - 'plain text' - """ - # Extract raw content - # For ToolMessages (responses from tools), extract from content - if hasattr(message, "content"): - raw_content = message.content - logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}") - - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(raw_content, list): - for block in raw_content: - if isinstance(block, dict) and block.get('type') == 'text': - raw_content = block.get('text', '') - logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}") - break - - # If content is empty and this is an AIMessage with tool_calls, - # extract from args (this handles the initial tool call from content_input) - if not raw_content and hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "args" in tool_call: - return tool_call["args"] - else: - raw_content = str(message) - - # If content is already a dict or list, return it directly - if isinstance(raw_content, (dict, list)): - logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}") - return raw_content - - # Try to parse as JSON - if isinstance(raw_content, str): - # First, try direct JSON parsing - try: - parsed = json.loads(raw_content) - logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}") - return parsed - except (json.JSONDecodeError, ValueError): - pass - - # If that fails, try to extract JSON from the string - # This handles cases where the content is embedded in a larger string - import re - json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL) - for candidate in json_candidates: - try: - parsed = json.loads(candidate) - logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}") - return parsed - except (json.JSONDecodeError, ValueError): - continue - - # If all parsing attempts fail, return the raw content - logger.info(f"extract_content_payload: returning raw content (parsing failed)") - return raw_content diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py new file mode 100644 index 00000000..ce6d5dd4 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -0,0 +1,320 @@ +import asyncio +import json +from datetime import datetime, timedelta + + +from langchain.tools import tool +from pydantic import BaseModel, Field + + +from app.core.memory.src.search import ( + search_by_temporal, + search_by_keyword_temporal, +) + +def extract_tool_message_content(response): + """从agent响应中提取ToolMessage内容和工具名称""" + messages = response.get('messages', []) + + for message in messages: + if hasattr(message, 'tool_call_id') and hasattr(message, 'content'): + # 这是一个ToolMessage + tool_content = message.content + tool_name = None + + # 尝试获取工具名称 + if hasattr(message, 'name'): + tool_name = message.name + elif hasattr(message, 'tool_name'): + tool_name = message.tool_name + + try: + # 解析JSON内容 + parsed_content = json.loads(tool_content) + return { + 'tool_name': tool_name, + 'content': parsed_content + } + except json.JSONDecodeError: + # 如果不是JSON格式,直接返回内容 + return { + 'tool_name': tool_name, + 'content': tool_content + } + + return None + + +class TimeRetrievalInput(BaseModel): + """时间检索工具的输入模式""" + context: str = Field(description="用户输入的查询内容") + group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") + +def create_time_retrieval_tool(group_id: str): + """ + 创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) + """ + + def clean_temporal_result_fields(data): + """ + 清理时间搜索结果中不需要的字段,并修改结构 + + Args: + data: 要清理的数据 + + Returns: + 清理后的数据 + """ + # 需要过滤的字段列表 + fields_to_remove = { + 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', + 'valid_at', 'invalid_at', 'statement_ids' + } + + if isinstance(data, dict): + cleaned = {} + for key, value in data.items(): + if key == 'statements' and isinstance(value, dict) and 'statements' in value: + # 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]} + cleaned_value = clean_temporal_result_fields(value) + # 进一步将内部的 statements 改为 time_search + if 'statements' in cleaned_value: + cleaned['results'] = { + 'time_search': cleaned_value['statements'] + } + else: + cleaned['results'] = cleaned_value + elif key not in fields_to_remove: + cleaned[key] = clean_temporal_result_fields(value) + return cleaned + elif isinstance(data, list): + return [clean_temporal_result_fields(item) for item in data] + else: + return data + + @tool + def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: + """ + 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 + 显式接收参数: + - context: 查询上下文内容 + - start_date: 开始时间(可选,格式:YYYY-MM-DD) + - end_date: 结束时间(可选,格式:YYYY-MM-DD) + - group_id_param: 组ID(可选,用于覆盖默认组ID) + - clean_output: 是否清理输出中的元数据字段 + -end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + """ + async def _async_search(): + # 使用传入的参数或默认值 + actual_group_id = group_id_param or group_id + actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") + actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") + + # 基本时间搜索 + results = await search_by_temporal( + group_id=actual_group_id, + start_date=actual_start_date, + end_date=actual_end_date, + limit=10 + ) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_temporal_result_fields(results) + else: + cleaned_results = results + + return json.dumps(cleaned_results, ensure_ascii=False, indent=2) + + return asyncio.run(_async_search()) + + @tool + def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: + """ + 优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段 + 显式接收参数: + - context: 查询内容 + - days_back: 向前搜索的天数,默认7天 + - start_date: 开始时间(可选,格式:YYYY-MM-DD) + - end_date: 结束时间(可选,格式:YYYY-MM-DD) + - clean_output: 是否清理输出中的元数据字段 + - end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + """ + async def _async_search(): + actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") + actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") + + # 关键词时间搜索 + results = await search_by_keyword_temporal( + query_text=context, + group_id=group_id, + start_date=actual_start_date, + end_date=actual_end_date, + limit=15 + ) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_temporal_result_fields(results) + else: + cleaned_results = results + + return json.dumps(cleaned_results, ensure_ascii=False, indent=2) + + return asyncio.run(_async_search()) + + return TimeRetrievalWithGroupId + + +def create_hybrid_retrieval_tool_async(memory_config, **search_params): + """ + 创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段 + + Args: + memory_config: 内存配置对象 + **search_params: 搜索参数,包含group_id, limit, include等 + """ + + def clean_result_fields(data): + """ + 递归清理结果中不需要的字段 + + Args: + data: 要清理的数据(可能是字典、列表或其他类型) + + Returns: + 清理后的数据 + """ + # 需要过滤的字段列表 + fields_to_remove = { + 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', + 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', + 'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" + } + + if isinstance(data, dict): + # 对字典进行清理 + cleaned = {} + for key, value in data.items(): + if key not in fields_to_remove: + cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据 + return cleaned + elif isinstance(data, list): + # 对列表中的每个元素进行清理 + return [clean_result_fields(item) for item in data] + else: + # 其他类型直接返回 + return data + + @tool + async def HybridSearch( + context: str, + search_type: str = "hybrid", + limit: int = 10, + group_id: str = None, + rerank_alpha: float = 0.6, + use_forgetting_rerank: bool = False, + use_llm_rerank: bool = False, + clean_output: bool = True # 新增:是否清理输出字段 + ) -> str: + """ + 优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段 + + Args: + context: 查询内容 + search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') + limit: 结果数量限制 + group_id: 组ID,用于过滤搜索结果 + rerank_alpha: 重排序权重参数 + use_forgetting_rerank: 是否使用遗忘重排序 + use_llm_rerank: 是否使用LLM重排序 + clean_output: 是否清理输出中的元数据字段 + """ + try: + # 导入run_hybrid_search函数 + from app.core.memory.src.search import run_hybrid_search + + # 合并参数,优先使用传入的参数 + final_params = { + "query_text": context, + "search_type": search_type, + "group_id": group_id or search_params.get("group_id"), + "limit": limit or search_params.get("limit", 10), + "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), + "output_path": None, # 不保存到文件 + "memory_config": memory_config, + "rerank_alpha": rerank_alpha, + "use_forgetting_rerank": use_forgetting_rerank, + "use_llm_rerank": use_llm_rerank + } + + # 执行混合检索 + raw_results = await run_hybrid_search(**final_params) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_result_fields(raw_results) + else: + cleaned_results = raw_results + + # 格式化返回结果 + formatted_results = { + "search_query": context, + "search_type": search_type, + "results": cleaned_results + } + + return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str) + + except Exception as e: + error_result = { + "error": f"混合检索失败: {str(e)}", + "search_query": context, + "search_type": search_type, + "timestamp": datetime.now().isoformat() + } + return json.dumps(error_result, ensure_ascii=False, indent=2) + + return HybridSearch + + +def create_hybrid_retrieval_tool_sync(memory_config, **search_params): + """ + 创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段 + + Args: + memory_config: 内存配置对象 + **search_params: 搜索参数 + """ + @tool + def HybridSearchSync( + context: str, + search_type: str = "hybrid", + limit: int = 10, + group_id: str = None, + clean_output: bool = True + ) -> str: + """ + 优化的混合检索工具(同步版本),自动过滤不需要的元数据字段 + + Args: + context: 查询内容 + search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') + limit: 结果数量限制 + group_id: 组ID,用于过滤搜索结果 + clean_output: 是否清理输出中的元数据字段 + """ + async def _async_search(): + # 创建异步工具并执行 + async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) + return await async_tool.ainvoke({ + "context": context, + "search_type": search_type, + "limit": limit, + "group_id": group_id, + "clean_output": clean_output + }) + + return asyncio.run(_async_search()) + + return HybridSearchSync \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index ae333e84..5a6f1e28 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,30 +1,32 @@ + import asyncio -import json import sys import warnings from contextlib import asynccontextmanager -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.llm_tools import WriteState -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage + +from langchain_core.messages import HumanMessage from langgraph.constants import END, START from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode + + +from app.db import get_db +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.llm_tools import WriteState +from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write +from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) - logger = get_agent_logger(__name__) if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - @asynccontextmanager -async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig): +async def make_write_graph(): """ Create a write graph workflow for memory operations. - + Args: user_id: User identifier tools: MCP tools loaded from session @@ -32,43 +34,8 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me group_id: Group identifier memory_config: MemoryConfig object containing all configuration """ - logger.info("Loading MCP tools: %s", [t.name for t in tools]) - logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})") - - data_write_tool = next((t for t in tools if t.name == "Data_write"), None) - - if not data_write_tool: - logger.error("Data_write tool not found", exc_info=True) - raise ValueError("Data_write tool not found") - - write_node = ToolNode([data_write_tool]) - - async def call_model(state): - messages = state["messages"] - last_message = messages[-1] - content = last_message[1] if isinstance(last_message, tuple) else last_message.content - - # Call Data_write directly with memory_config - write_params = { - "content": content, - "apply_id": apply_id, - "group_id": group_id, - "user_id": user_id, - "memory_config": memory_config, - } - logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}") - - write_result = await data_write_tool.ainvoke(write_params) - - if isinstance(write_result, dict): - result_content = write_result.get("data", str(write_result)) - else: - result_content = str(write_result) - logger.info("Write content: %s", result_content) - return {"messages": [AIMessage(content=result_content)]} - workflow = StateGraph(WriteState) - workflow.add_node("content_input", call_model) + workflow.add_node("content_input", content_input_write) workflow.add_node("save_neo4j", write_node) workflow.add_edge(START, "content_input") workflow.add_edge("content_input", "save_neo4j") @@ -76,5 +43,45 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me graph = workflow.compile() - yield graph + + +async def main(): + """主函数 - 运行工作流""" + message = "今天周一" + group_id = 'new_2025test1103' # 组ID + + + # 获取数据库会话 + db_session = next(get_db()) + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=17, # 改为整数 + service_name="MemoryAgentService" + ) + try: + async with make_write_graph() as graph: + config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} + + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + if 'save_neo4j'==node_name: + massages=node_data + massages=massages.get('write_result')['status'] + print(massages) # | 更新数据: {node_data} + + except Exception as e: + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/__init__.py b/api/app/core/memory/agent/mcp_server/__init__.py deleted file mode 100644 index efd03773..00000000 --- a/api/app/core/memory/agent/mcp_server/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -MCP Server package for memory agent. - -This package provides the FastMCP server implementation with context-based -dependency injection for tool functions. - -Package structure: -- server: FastMCP server initialization and context setup -- tools: MCP tool implementations -- models: Pydantic response models -- services: Business logic services -""" -# from app.core.memory.agent.mcp_server.server import ( -# mcp, -# initialize_context, -# main, -# get_context_resource -# ) - -# # Import tools to register them (but don't export them) -# from app.core.memory.agent.mcp_server import tools - -# __all__ = [ -# 'mcp', -# 'initialize_context', -# 'main', -# 'get_context_resource', -# ] \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/mcp_instance.py b/api/app/core/memory/agent/mcp_server/mcp_instance.py deleted file mode 100644 index 3a2eeb78..00000000 --- a/api/app/core/memory/agent/mcp_server/mcp_instance.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -MCP Server Instance - -This module contains the FastMCP server instance that is shared across all modules. -It's in a separate file to avoid circular import issues. -""" -from mcp.server.fastmcp import FastMCP - -# Initialize FastMCP server instance -# This instance is shared across all tool modules -mcp = FastMCP('data_flow') diff --git a/api/app/core/memory/agent/mcp_server/server.py b/api/app/core/memory/agent/mcp_server/server.py deleted file mode 100644 index 26f24824..00000000 --- a/api/app/core/memory/agent/mcp_server/server.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -MCP Server initialization with FastMCP context setup. - -This module initializes the FastMCP server and registers shared resources -in the context for dependency injection into tool functions. -""" -import os -import sys - -from app.core.config import settings -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.services.search_service import SearchService -from app.core.memory.agent.mcp_server.services.session_service import SessionService -from app.core.memory.agent.mcp_server.services.template_service import TemplateService -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.redis_tool import store - -logger = get_agent_logger(__name__) - - -def get_context_resource(ctx, resource_name: str): - """ - Helper function to retrieve a resource from the FastMCP context. - - Args: - ctx: FastMCP Context object (passed to tool functions) - resource_name: Name of the resource to retrieve - - Returns: - The requested resource - - Raises: - AttributeError: If the resource doesn't exist - - Example: - @mcp.tool() - async def my_tool(ctx: Context): - template_service = get_context_resource(ctx, 'template_service') - llm_client = get_context_resource(ctx, 'llm_client') - """ - if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None: - raise RuntimeError("Context does not have fastmcp attribute") - - if not hasattr(ctx.fastmcp, resource_name): - raise AttributeError( - f"Resource '{resource_name}' not found in context. " - f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}" - ) - - return getattr(ctx.fastmcp, resource_name) - - -def initialize_context(): - """ - Initialize and register shared resources in FastMCP context. - - This function sets up all shared resources that will be available - to tool functions via dependency injection through the context parameter. - - Resources are stored as attributes on the FastMCP instance and can be - accessed via ctx.fastmcp in tool functions. - - Resources registered: - - session_store: RedisSessionStore for session management - - llm_client: LLM client for structured API calls - - app_settings: Application settings (renamed to avoid conflict with FastMCP settings) - - template_service: Service for template rendering - - search_service: Service for hybrid search - - session_service: Service for session operations - """ - try: - # Register Redis session store - logger.info("Registering session_store in context") - mcp.session_store = store - - # Note: LLM client is NOT loaded at server startup - # It should be loaded dynamically when needed, with config_id passed explicitly - # to make_write_graph or make_read_graph functions - logger.info("LLM client will be loaded dynamically with config_id when needed") - mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id - - # Register application settings (renamed to avoid conflict with FastMCP's settings) - logger.info("Registering app_settings in context") - mcp.app_settings = settings - - # Register template service - template_root = PROJECT_ROOT_ + '/agent/utils/prompt' - # logger.info(f"Registering template_service in context with root: {template_root}") - template_service = TemplateService(template_root) - mcp.template_service = template_service - - # Register search service - # logger.info("Registering search_service in context") - search_service = SearchService() - mcp.search_service = search_service - - # Register session service - # logger.info("Registering session_service in context") - session_service = SessionService(store) - mcp.session_service = session_service - - # logger.info("All context resources registered successfully") - - except Exception as e: - logger.error(f"Failed to initialize context: {e}", exc_info=True) - raise - - -def main(): - """ - Main entry point for the MCP server. - - Initializes context and starts the server with SSE transport. - """ - try: - logger.info("Starting MCP server initialization") - # Initialize context resources - initialize_context() - - # Import and register tools (imports trigger tool registration) - from app.core.memory.agent.mcp_server.tools import ( # noqa: F401 - data_tools, - problem_tools, - retrieval_tools, - summary_tools, - verification_tools, - ) - - # Tools are registered via imports above - - # Get MCP port from environment (default: 8081) - mcp_port = int(os.getenv("MCP_PORT", "8081")) - logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") - - # Configure DNS rebinding protection for Docker container compatibility - from mcp.server.fastmcp.server import TransportSecuritySettings - - # Disable DNS rebinding protection to allow Docker container hostnames - # This allows containers to connect using service names like 'mcp-server' - mcp.settings.transport_security = TransportSecuritySettings( - enable_dns_rebinding_protection=False, - ) - logger.info("DNS rebinding protection: disabled for Docker container compatibility") - - # logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") - - # Run the server with SSE transport for HTTP connections - import uvicorn - app = mcp.sse_app() - uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info") - - except Exception as e: - logger.error(f"Failed to start MCP server: {e}", exc_info=True) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/agent/mcp_server/tools/__init__.py b/api/app/core/memory/agent/mcp_server/tools/__init__.py deleted file mode 100644 index 5ce04ef3..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -MCP Tools module. - -This module contains all MCP tool implementations organized by functionality. - -Tools are organized into the following modules: -- problem_tools: Question segmentation and extension -- retrieval_tools: Database and context retrieval -- verification_tools: Data verification -- summary_tools: Summarization and summary retrieval -- data_tools: Data type differentiation and writing -""" - -# Import all tool modules to register them with the MCP server -from . import problem_tools -from . import retrieval_tools -from . import verification_tools -from . import summary_tools -from . import data_tools - -__all__ = [ - 'problem_tools', - 'retrieval_tools', - 'verification_tools', - 'summary_tools', - 'data_tools', -] diff --git a/api/app/core/memory/agent/mcp_server/tools/data_tools.py b/api/app/core/memory/agent/mcp_server/tools/data_tools.py deleted file mode 100644 index 631f7fd7..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/data_tools.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -Data Tools for data type differentiation and writing. - -This module contains MCP tools for distinguishing data types and writing data. -""" - -import os - -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.retrieval_models import ( - DistinguishTypeResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.write_tools import write -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Data_type_differentiation( - ctx: Context, - context: str, - memory_config: MemoryConfig, -) -> dict: - """ - Distinguish the type of data (read or write). - - Args: - ctx: FastMCP context for dependency injection - context: Text to analyze for type differentiation - memory_config: MemoryConfig object containing LLM configuration - - Returns: - dict: Contains 'context' with the original text and 'type' field - """ - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - - # Get LLM client from memory_config using factory pattern - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='distinguish_types_prompt.jinja2', - operation_name='status_typle', - user_query=context - ) - except Exception as e: - logger.error( - f"Template rendering failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "type": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=DistinguishTypeResponse - ) - - result = structured.model_dump() - - # Add context to result - result["context"] = context - - return result - - except Exception as e: - logger.error( - f"LLM call failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": f"LLM call failed: {str(e)}" - } - - except Exception as e: - logger.error( - f"Data_type_differentiation failed: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": str(e) - } - - -@mcp.tool() -async def Data_write( - ctx: Context, - content: str, - user_id: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, -) -> dict: - """ - Write data to the database/file system. - - Args: - ctx: FastMCP context for dependency injection - content: Data content to write - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - - Returns: - dict: Contains 'status', 'saved_to', and 'data' fields - """ - try: - # Ensure output directory exists - os.makedirs("data_output", exist_ok=True) - file_path = os.path.join("data_output", "user_data.csv") - - # Write data - clients are constructed inside write() from memory_config - await write( - content=content, - user_id=user_id, - apply_id=apply_id, - group_id=group_id, - memory_config=memory_config, - ) - logger.info(f"Write completed successfully! Config: {memory_config.config_name}") - - return { - "status": "success", - "saved_to": file_path, - "data": content, - "config_id": memory_config.config_id, - "config_name": memory_config.config_name, - } - - except Exception as e: - logger.error(f"Data_write failed: {e}", exc_info=True) - return { - "status": "error", - "message": str(e), - } diff --git a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py b/api/app/core/memory/agent/mcp_server/tools/problem_tools.py deleted file mode 100644 index 49812e38..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -Problem Tools for question segmentation and extension. - -This module contains MCP tools for breaking down and extending user questions. -LLM clients are constructed from MemoryConfig when needed. -""" - -import json -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.problem_models import ( - ProblemBreakdownResponse, - ProblemExtensionResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Split_The_Problem( - ctx: Context, - sentence: str, - sessionid: str, - messages_id: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, -) -> dict: - """ - Segment the dialogue or sentence into sub-problems. - - Args: - ctx: FastMCP context for dependency injection - sentence: Original sentence to split - sessionid: Session identifier - messages_id: Message identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - - Returns: - dict: Contains 'context' (JSON string of split results) and 'original' sentence - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Extract user ID from session - user_id = session_service.resolve_user_id(sessionid) - - # Get conversation history - history = await session_service.get_history(user_id, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='problem_breakdown_prompt.jinja2', - operation_name='split_the_problem', - history=history, - sentence=sentence - ) - except Exception as e: - logger.error( - f"Template rendering failed for Split_The_Problem: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemBreakdownResponse - ) - - # Handle RootModel response with .root attribute access - if structured is None: - # LLM returned None, use empty list as fallback - split_result = json.dumps([], ensure_ascii=False) - elif hasattr(structured, 'root') and structured.root is not None: - split_result = json.dumps( - [item.model_dump() for item in structured.root], - ensure_ascii=False - ) - elif isinstance(structured, list): - # Fallback: treat structured itself as the list - split_result = json.dumps( - [item.model_dump() for item in structured], - ensure_ascii=False - ) - else: - # Last resort: use empty list - split_result = json.dumps([], ensure_ascii=False) - - except Exception as e: - logger.error( - f"LLM call failed for Split_The_Problem: {e}", - exc_info=True - ) - split_result = json.dumps([], ensure_ascii=False) - - logger.info("Problem splitting") - logger.info(f"Problem split result: {split_result}") - - # Emit intermediate output for frontend - result = { - "context": split_result, - "original": sentence, - "_intermediate": { - "type": "problem_split", - "data": json.loads(split_result) if split_result else [], - "original_query": sentence - } - } - - return result - - except Exception as e: - logger.error( - f"Split_The_Problem failed: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Problem splitting', duration) - - -@mcp.tool() -async def Problem_Extension( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Extend the problem with additional sub-questions. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing split problem results - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'context' (aggregated questions) and 'original' question - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID from usermessages - from app.core.memory.agent.utils.messages_tool import Resolve_username - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Process context to extract questions - extent_quest, original = await Problem_Extension_messages_deal(context) - - # Format questions for template rendering - questions_formatted = [] - for msg in extent_quest: - if msg.get("role") == "user": - questions_formatted.append(msg.get("content", "")) - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='Problem_Extension_prompt.jinja2', - operation_name='problem_extension', - history=history, - questions=questions_formatted - ) - except Exception as e: - logger.error( - f"Template rendering failed for Problem_Extension: {e}", - exc_info=True - ) - return { - "context": {}, - "original": original, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - response_content = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemExtensionResponse - ) - - # Aggregate results by original question - aggregated_dict = {} - for item in response_content.root: - key = getattr(item, "original_question", None) or ( - item.get("original_question") if isinstance(item, dict) else None - ) - value = getattr(item, "extended_question", None) or ( - item.get("extended_question") if isinstance(item, dict) else None - ) - if not key or not value: - continue - aggregated_dict.setdefault(key, []).append(value) - - except Exception as e: - logger.error( - f"LLM call failed for Problem_Extension: {e}", - exc_info=True - ) - aggregated_dict = {} - - logger.info("Problem extension") - logger.info(f"Problem extension result: {aggregated_dict}") - - # Emit intermediate output for frontend - result = { - "context": aggregated_dict, - "original": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "problem_extension", - "data": aggregated_dict, - "original_query": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - return result - - except Exception as e: - logger.error( - f"Problem_Extension failed: {e}", - exc_info=True - ) - return { - "context": {}, - "original": context.get("original", ""), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Problem extension', duration) diff --git a/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py b/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py deleted file mode 100644 index db18ba04..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -Retrieval Tools for database and context retrieval. - -This module contains MCP tools for retrieving data using hybrid search. -""" - -import os -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.llm_tools import ( - deduplicate_entries, - merge_to_key_value_pairs, -) -from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal -from app.core.rag.nlp.search import knowledge_retrieval -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from mcp.server.fastmcp import Context - -load_dotenv() -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Retrieve( - ctx: Context, - context, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Retrieve data from the database using hybrid search. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary or string containing query information - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'context' with Query and Expansion_issue results - """ - kb_config = { - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id": os.getenv('reranker_id'), - "reranker_top_k": 10 - } - start = time.time() - logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}") - - try: - # Extract services from context - search_service = get_context_resource(ctx, 'search_service') - - databases_anser = [] - - # Handle both dict and string context - if isinstance(context, dict): - # Process dict context with extended questions - all_items = [] - logger.info(f"Retrieve: context keys={list(context.keys())}") - content, original = await Retriev_messages_deal(context) - logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}") - logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'") - - if not original: - logger.warning(f"Retrieve: original query is empty! context={context}") - - # Extract all query items from content - # content is like {original_question: [extended_questions...], ...} - for key, values in content.items(): - if isinstance(values, list): - all_items.extend(values) - elif isinstance(values, str): - all_items.append(values) - elif values is not None: - # Fallback: convert non-empty non-list values to string - all_items.append(str(values)) - - # Execute search for each question - for idx, question in enumerate(all_items): - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": question, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query=question - raw_results=clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results='' - cleaned_query = question - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - - databases_anser.append({ - "Query_small": cleaned_query, - "Result_small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": idx + 1, - "total": len(all_items) - } - }) - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for question '{question}': {e}", - exc_info=True - ) - # Continue with empty result for this question - databases_anser.append({ - "Query_small": question, - "Result_small": "" - }) - - # Build initial database data structure - databases_data = { - "Query": original, - "Expansion_issue": databases_anser - } - - # Collect intermediate outputs before deduplication - intermediate_outputs = [] - for item in databases_anser: - if '_intermediate' in item: - intermediate_outputs.append(item['_intermediate']) - - # Deduplicate and merge results - deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) - deduplicated_data_merged = merge_to_key_value_pairs( - deduplicated_data, - 'Query_small', - 'Result_small' - ) - - # Restructure for Verify/Retrieve_Summary compatibility - keys, val = [], [] - for item in deduplicated_data_merged: - for items_key, items_value in item.items(): - keys.append(items_key) - val.append(items_value) - - send_verify = [] - for i, j in zip(keys, val, strict=False): - send_verify.append({ - "Query_small": i, - "Answer_Small": j - }) - - dup_databases = { - "Query": original, - "Expansion_issue": send_verify, - "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs - } - - logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") - - else: - # Handle string context (simple query) - query = str(context).strip() - - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query = query - raw_results = clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results = '' - cleaned_query = query - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - # Keep structure for Verify/Retrieve_Summary compatibility - dup_databases = { - "Query": cleaned_query, - "Expansion_issue": [{ - "Query_small": cleaned_query, - "Answer_Small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": 1, - "total": 1 - } - }] - } - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for query '{query}': {e}", - exc_info=True - ) - # Return empty results on failure - dup_databases = { - "Query": query, - "Expansion_issue": [] - } - - logger.info( - f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, " - f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}" - ) - - # Build result with intermediate outputs - result = { - "context": dup_databases, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - # Add intermediate outputs list if they exist - intermediate_outputs = dup_databases.get('_intermediate_outputs', []) - if intermediate_outputs: - result['_intermediates'] = intermediate_outputs - logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result") - else: - logger.warning("No intermediate outputs found in dup_databases") - - return result - - except Exception as e: - logger.error( - f"Retrieve failed: {e}", - exc_info=True - ) - return { - "context": { - "Query": "", - "Expansion_issue": [] - }, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval', duration) diff --git a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py deleted file mode 100644 index 0f306572..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py +++ /dev/null @@ -1,640 +0,0 @@ -""" -Summary Tools for data summarization. - -This module contains MCP tools for summarizing retrieved data and generating responses. -LLM clients are constructed from MemoryConfig when needed. -""" - -import json -import os -import re -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.summary_models import ( - RetrieveSummaryResponse, - SummaryResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.messages_tool import ( - Resolve_username, - Summary_messages_deal, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.rag.nlp.search import knowledge_retrieval -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from mcp.server.fastmcp import Context - -load_dotenv() -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Summary( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Summarize the verified data. - - Args: - ctx: FastMCP context for dependency injection - context: JSON string containing verified data - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Process context to extract answer and query - answer_small, query = await Summary_messages_deal(context) - - - start_time= time.time() - history = await session_service.get_history(sessionid, apply_id, group_id) - end_time=time.time() - logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}") - data = { - "query": query, - "history": history, - "retrieve_info": answer_small - } - - except Exception as e: - logger.error( - f"Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='summary_prompt.jinja2', - operation_name='summary', - data=data, - query=query - ) - except Exception as e: - logger.error( - f"Template rendering failed for Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=SummaryResponse - ) - - aimessages = structured.query_answer or "" - - except Exception as e: - logger.error( - f"LLM call failed for Summary: {e}", - exc_info=True - ) - aimessages = "" - - try: - # Save session - if aimessages != "": - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}", - exc_info=True - ) - return { - "status": "error", - "message": str(e) - } - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"Summary after verification: {aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Summary', duration) - - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - -@mcp.tool() -async def Retrieve_Summary( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Summarize data directly from retrieval results. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing Query and Expansion_issue from Retrieve - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - - - # Handle both 'content' and 'context' keys (LangGraph uses 'content') - logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}") - - if isinstance(context, dict): - if "content" in context: - inner = context["content"] - # If it's a JSON string, parse it - if isinstance(inner, str): - try: - parsed = json.loads(inner) - logger.info("Retrieve_Summary: successfully parsed JSON") - except json.JSONDecodeError: - # Try unescaping first - try: - unescaped = inner.encode('utf-8').decode('unicode_escape') - parsed = json.loads(unescaped) - logger.info("Retrieve_Summary: parsed after unescaping") - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error( - f"Retrieve_Summary: parsing failed even after unescape: {e}" - ) - context_dict = {"Query": "", "Expansion_issue": []} - parsed = None - - if parsed: - # Check if parsed has 'context' wrapper - if isinstance(parsed, dict) and "context" in parsed: - context_dict = parsed["context"] - else: - context_dict = parsed - elif isinstance(inner, dict): - context_dict = inner - else: - context_dict = {"Query": "", "Expansion_issue": []} - elif "context" in context: - context_dict = context["context"] if isinstance(context["context"], dict) else context - else: - context_dict = context - else: - context_dict = {"Query": "", "Expansion_issue": []} - - query = context_dict.get("Query", "") - expansion_issue = context_dict.get("Expansion_issue", []) - - logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}") - logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}") - - # Extract retrieve_info from expansion_issue - retrieve_info = [] - for item in expansion_issue: - # Check for both Answer_Small and Answer_Small (typo) for backward compatibility - answer = None - if isinstance(item, dict): - if "Answer_Small" in item: - answer = item["Answer_Small"] - - - if answer is not None: - # Handle both string and list formats - if isinstance(answer, list): - # Join list of characters/strings into a single string - retrieve_info.append(''.join(str(x) for x in answer)) - elif isinstance(answer, str): - retrieve_info.append(answer) - else: - retrieve_info.append(str(answer)) - - # Join all retrieve_info into a single string - retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else "" - - start_time=time.time() - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - end_time=time.time() - logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}") - except Exception as e: - logger.error( - f"Retrieve_Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='Retrieve_Summary_prompt.jinja2', - operation_name='retrieve_summary', - query=query, - history=history, - retrieve_info=retrieve_info_str - ) - except Exception as e: - logger.error( - f"Template rendering failed for Retrieve_Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=RetrieveSummaryResponse - ) - - # Handle case where structured response might be None or incomplete - if structured and hasattr(structured, 'data') and structured.data: - aimessages = structured.data.query_answer or "" - else: - logger.warning("Structured response is None or incomplete, using default message") - aimessages = "信息不足,无法回答" - - - # Check for insufficient information response - if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="": - # Save session - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"Retrieve_Summary: LLM call failed: {e}", - exc_info=True - ) - aimessages = "" - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"Summary after retrieval: {aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval summary', duration) - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "retrieval_summary", - "summary": aimessages, - "query": query, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - -@mcp.tool() -async def Input_Summary( - ctx: Context, - context: str, - usermessages: str, - search_switch: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Generate a quick summary for direct input without verification. - - Args: - ctx: FastMCP context for dependency injection - context: String containing the input sentence - usermessages: User messages identifier - search_switch: Search switch value for routing ('2' for summaries only) - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'query_answer' with the summary result - """ - start = time.time() - logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - - try: - # Extract services from context - session_service = get_context_resource(ctx, "session_service") - search_service = get_context_resource(ctx, "search_service") - - # Resolve session ID - sessionid = Resolve_username(usermessages) or "" - sessionid = sessionid.replace('call_id_', '') - - start_time=time.time() - history = await session_service.get_history( - str(sessionid), - str(apply_id), - str(group_id) - ) - end_time=time.time() - logger.info(f"Input_Summary-REDIS搜索:{end_time - start_time}") - # Override with empty list for now (as in original) - - # Log the raw context for debugging - logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}") - - # Extract sentence from context - # Context can be a string or might contain the sentence in various formats - try: - # Try to parse as JSON first - if isinstance(context, str) and (context.startswith('{') or context.startswith('[')): - try: - import json - context_dict = json.loads(context) - if isinstance(context_dict, dict): - query = context_dict.get('sentence', context_dict.get('content', context)) - else: - query = context - except json.JSONDecodeError: - # Not valid JSON, try regex - match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context) - query = match.group(1) if match else context - else: - query = context - except Exception as e: - logger.warning(f"Failed to extract query from context: {e}") - query = context - - # Clean query - query = str(query).strip().strip("\"'") - - logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}") - - # Execute search based on search_switch and storage_type - try: - logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}") - - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - - # Retrieval - if search_switch == '2': - search_params["include"] = ["summaries"] - if storage_type == "rag" and user_rag_memory_id: - raw_results = [] - retrieve_info = "" - kb_config={ - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id":os.getenv('reranker_id'), - "reranker_top_k": 10 - } - - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - retrieve_info = '\n\n'.join(retrieval_knowledge) - raw_results=[retrieve_info] - logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}") - except: - retrieve_info='' - raw_results=[''] - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - logger.info("Input_Summary: Using summary for retrieval") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - - except Exception as e: - logger.error( - f"Input_Summary: hybrid_search failed, using empty results: {e}", - exc_info=True - ) - retrieve_info, question, raw_results = "", query, [] - - # Return retrieved information directly without LLM processing - # Use the raw retrieved info as the answer - aimessages = retrieve_info if retrieve_info else "信息不足,无法回答" - - logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...") - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "input_summary", - "title": "快速答案", - "summary": aimessages, - "query": query, - "raw_results": raw_results, - "search_mode": "quick_search", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Input_Summary failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "信息不足,无法回答", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval', duration) - - -@mcp.tool() -async def Summary_fails( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Handle workflow failure when summary cannot be generated. - - Args: - ctx: FastMCP context for dependency injection - context: Failure context string - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'query_answer' with failure message - """ - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Parse session ID from usermessages - usermessages_parts = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages_parts[:-1]) - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - logger.info("没有相关数据") - logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}") - - return { - "status": "success", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - except Exception as e: - logger.error( - f"Summary_fails failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } diff --git a/api/app/core/memory/agent/mcp_server/tools/verification_tools.py b/api/app/core/memory/agent/mcp_server/tools/verification_tools.py deleted file mode 100644 index cb6af5bd..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/verification_tools.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -Verification Tools for data verification. - -This module contains MCP tools for verifying retrieved data. -""" -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.messages_tool import ( - Resolve_username, - Retrieve_verify_tool_messages_deal, - Verify_messages_deal, -) -from app.core.memory.agent.utils.verify_tool import VerifyTool -from app.schemas.memory_config_schema import MemoryConfig -from jinja2 import Template -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Verify( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Verify the retrieved data. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing query and expansion issues - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'verified_data' with verification results - """ - start = time.time() - - - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Load verification prompt template - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2' - - # Read template file directly (VerifyTool expects raw template content) - from app.core.memory.agent.utils.messages_tool import read_template_file - system_prompt = await read_template_file(file_path) - - - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - - template = Template(system_prompt) - system_prompt = template.render(history=history, sentence=context) - - # Process context to extract query and results - Query_small, Result_small, query = await Verify_messages_deal(context) - - # Build query list for verification - query_list = [] - for query_small, anser in zip(Query_small, Result_small, strict=False): - query_list.append({ - 'Query_small': query_small, - 'Answer_Small': anser - }) - - messages = { - "Query": query, - "Expansion_issue": query_list - } - - - - # Call verification workflow with LLM model ID from memory_config - verify_tool = VerifyTool( - system_prompt=system_prompt, - verify_data=messages, - llm_model_id=str(memory_config.llm_model_id) - ) - verify_result = await verify_tool.verify() - - # Parse LLM verification result with error handling - try: - messages_deal = await Retrieve_verify_tool_messages_deal( - verify_result, - history, - query - ) - except Exception as e: - logger.error( - f"Retrieve_verify_tool_messages_deal parsing failed: {e}", - exc_info=True - ) - # Fallback to avoid 500 errors - messages_deal = { - "data": { - "query": query, - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": history, - } - - logger.info(f"Verification result: {messages_deal}") - - # Emit intermediate output for frontend - return { - "status": "success", - "verified_data": messages_deal, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "verification", - "title": "Data Verification", - "result": messages_deal.get("split_result", "unknown"), - "reason": messages_deal.get("reason", ""), - "query": query, - "verified_count": len(query_list), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Verify failed: {e}", - exc_info=True - ) - return { - "status": "error", - "message": str(e), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "verified_data": { - "data": { - "query": "", - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": [], - } - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Verification', duration) diff --git a/api/app/core/memory/agent/mcp_server/models/__init__.py b/api/app/core/memory/agent/models/__init__.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/__init__.py rename to api/app/core/memory/agent/models/__init__.py diff --git a/api/app/core/memory/agent/mcp_server/models/problem_models.py b/api/app/core/memory/agent/models/problem_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/problem_models.py rename to api/app/core/memory/agent/models/problem_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/retrieval_models.py b/api/app/core/memory/agent/models/retrieval_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/retrieval_models.py rename to api/app/core/memory/agent/models/retrieval_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/summary_models.py b/api/app/core/memory/agent/models/summary_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/summary_models.py rename to api/app/core/memory/agent/models/summary_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/verification_models.py b/api/app/core/memory/agent/models/verification_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/verification_models.py rename to api/app/core/memory/agent/models/verification_models.py diff --git a/api/app/core/memory/agent/multimodal/oss_picture.py b/api/app/core/memory/agent/multimodal/oss_picture.py deleted file mode 100644 index b5b4bd6b..00000000 --- a/api/app/core/memory/agent/multimodal/oss_picture.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -import sys -import traceback - -import requests - -# from qcloud_cos import CosConfig, CosS3Client -# from qcloud_cos.cos_exception import CosClientError, CosServiceError - -# from config.paths import BASE_DIR -BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0])) - -class OSSUploader: - """对象存储文件上传工具类""" - - def __init__(self, env): - api = { - "test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon", - "prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon" - } - self.api = api.get(env, "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon") - self.privacy = "false" - self.headers = { - "User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) ' - 'AppleWebKit/537.36 (KHTML, like Gecko)' - ' Chrome/133.0.6833.84 Safari/537.36' - } - - @staticmethod - def _generate_object_key(file_path, prefix='xhs_'): - """ - 生成对象存储的Key - - :param file_path: 本地文件路径 - :param prefix: 存储前缀,用于分类存储 - :return: 生成的对象Key - """ - # 文件md5值.后缀名 - filename = os.path.basename(file_path) - filename = f"{filename}" - - # 组合成完整的对象Key - return f"{prefix}{filename}" - - def upload_image(self, file_name, prefix='jd_'): - """ - 上传文件到COS并返回可访问的URL - - :param file_url: 文件路径 - :param file_name: 文件名称 - :param media_type: 文件类型 - :param prefix: 存储前缀,用于分类存储 - :return: 文件访问URL - """ - # 检查文件是否存在 - - - - file_path = os.path.join(BASE_DIR, file_name) - - # response = requests.get(url, headers=self.headers, stream=True) - - # if response.status_code == 200: - # with open(file_path, "wb") as f: - # for chunk in response.iter_content(1024): # 分块写入,避免内存占用过大 - # f.write(chunk) - # else: - # raise Exception(f"文件下载失败,{file_name}") - - # 生成对象Key - object_key = self._generate_object_key(file_path, prefix +file_name.split('.')[-1]) - - try: - upload_response = requests.post( - self.api, - data={ - "privacy": self.privacy, - "fileName": object_key, - } - ) - - if upload_response.status_code != 200: - raise Exception('上传接口请求失败') - resp = upload_response.json() - name = resp["data"]["name"] - file_url = resp["data"]["path"] - policy = resp["data"]["policy"] - with open(file_path, 'rb') as f: - oss_push_resp = requests.post( - policy["host"], - files={ - "key": policy["dir"], - "OSSAccessKeyId": policy["accessid"], - "name": name, - "policy": policy["policy"], - "success_action_status": 200, - "signature": policy["signature"], - "file": f, - } - ) - if oss_push_resp.status_code == 200: - return file_url - raise Exception("OSS上传失败") - except Exception: - raise Exception(f"上传失败: \n{traceback.format_exc()}") - finally: - print('success') - # os.remove(file_path) - - -if __name__ == '__main__': - cos_uploader = OSSUploader("prod") - url =cos_uploader.upload_image('./example01.jpg') - print(url) diff --git a/api/app/core/memory/agent/multimodal/speech_model.py b/api/app/core/memory/agent/multimodal/speech_model.py deleted file mode 100644 index 2df32dd0..00000000 --- a/api/app/core/memory/agent/multimodal/speech_model.py +++ /dev/null @@ -1,121 +0,0 @@ -import asyncio -import re - -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_, picture_model_requests,Picture_recognize, Voice_recognize -from app.core.memory.agent.utils.messages_tool import read_template_file - -import requests -import json -import os -import time -# file_urls = [ -# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_female2.wav", -# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav", -# ] -class Vico_recognition: - def __init__(self,file_urls): - self.api_key='' - self.backend_model_name='' - self.api_base='' - self.file_urls=file_urls - - # 提交文件转写任务,包含待转写文件url列表 - async def submit_task(self) -> str: - self.api_key, self.backend_model_name, self.api_base =await Voice_recognize() - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "X-DashScope-Async": "enable", - } - data = { - "model": self.backend_model_name, - "input": {"file_urls": self.file_urls}, - "parameters": { - "channel_id": [0], - "vocabulary_id": "vocab-Xxxx", - }, - } - # 录音文件转写服务url - service_url = ( - "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription" - ) - response = requests.post( - service_url, headers=headers, data=json.dumps(data) - ) - - # 打印响应内容 - if response.status_code == 200: - return response.json()["output"]["task_id"] - else: - print("task failed!") - print(response.json()) - return None - - async def download_transcription_result(self, transcription_url): - """ - Args: - transcription_url (str): 转写结果文件URL - Returns: - dict: 转写结果内容 - """ - try: - response = requests.get(transcription_url) - response.raise_for_status() - return response.json() - except Exception as e: - print(f"下载转写结果失败: {e}") - return None - - # 循环查询任务状态直到成功 - async def wait_for_complete(self,task_id): - self.api_key, self.backend_model_name, self.api_base = await Voice_recognize() - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "X-DashScope-Async": "enable", - } - - pending = True - while pending: - # 查询任务状态服务url - service_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" - response = requests.post( - service_url, headers=headers - ) - if response.status_code == 200: - status = response.json()['output']['task_status'] - if status == 'SUCCEEDED': - print("task succeeded!") - pending = False - return response.json()['output']['results'] - elif status == 'RUNNING' or status == 'PENDING': - pass - else: - print("task failed!") - pending = False - else: - print("query failed!") - pending = False - time.sleep(0.1) - async def run(self): - self.api_key, self.backend_model_name, self.api_base = await Voice_recognize() - task_id=await self.submit_task() - result=await self.wait_for_complete(task_id) - result_context=[] - for i in result: - transcription_url=i['transcription_url'] - print(f"转写URL: {transcription_url}") - - # 下载并打印转写内容 - content = await self.download_transcription_result(transcription_url) - if content: - content=json.dumps(content, indent=2, ensure_ascii=False) - context=re.findall(r'"text": "(.*?)"', content) - result_context.append(context[0]) - result=''.join(result_context) - return (result) - - - - diff --git a/api/app/core/memory/agent/mcp_server/services/__init__.py b/api/app/core/memory/agent/services/__init__.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/services/__init__.py rename to api/app/core/memory/agent/services/__init__.py diff --git a/api/app/core/memory/agent/services/optimized_llm_service.py b/api/app/core/memory/agent/services/optimized_llm_service.py new file mode 100644 index 00000000..6942d421 --- /dev/null +++ b/api/app/core/memory/agent/services/optimized_llm_service.py @@ -0,0 +1,277 @@ +""" +优化的LLM服务类,用于压缩和统一LLM调用 +""" + +import asyncio +from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from app.core.logging_config import get_agent_logger +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.core.memory.llm_tools.openai_client import OpenAIClient + +T = TypeVar('T', bound=BaseModel) + +logger = get_agent_logger(__name__) + + +class OptimizedLLMService: + """ + 优化的LLM服务类,提供统一的LLM调用接口 + + 特性: + 1. 客户端复用 - 避免重复创建LLM客户端 + 2. 批量处理 - 支持并发处理多个请求 + 3. 错误处理 - 统一的错误处理和降级策略 + 4. 性能优化 - 缓存和连接池优化 + """ + + def __init__(self, db_session: Session): + self.db_session = db_session + self.client_factory = MemoryClientFactory(db_session) + self._client_cache: Dict[str, OpenAIClient] = {} + + def _get_cached_client(self, llm_model_id: str) -> OpenAIClient: + """获取缓存的LLM客户端,避免重复创建""" + if llm_model_id not in self._client_cache: + self._client_cache[llm_model_id] = self.client_factory.get_llm_client(llm_model_id) + return self._client_cache[llm_model_id] + + async def structured_response( + self, + llm_model_id: str, + system_prompt: str, + response_model: Type[T], + user_message: Optional[str] = None, + fallback_value: Optional[Any] = None + ) -> T: + """ + 统一的结构化响应接口 + + Args: + llm_model_id: LLM模型ID + system_prompt: 系统提示词 + response_model: 响应模型类 + user_message: 用户消息(可选) + fallback_value: 失败时的降级值 + + Returns: + 结构化响应对象 + """ + try: + llm_client = self._get_cached_client(llm_model_id) + + messages = [{"role": "system", "content": system_prompt}] + if user_message: + messages.append({"role": "user", "content": user_message}) + + logger.debug(f"LLM调用: model={llm_model_id}, prompt_length={len(system_prompt)}") + + structured = await llm_client.response_structured( + messages=messages, + response_model=response_model + ) + + if structured is None: + logger.warning(f"LLM返回None,使用降级值") + return self._create_fallback_response(response_model, fallback_value) + + return structured + + except Exception as e: + logger.error(f"结构化响应失败: {e}", exc_info=True) + return self._create_fallback_response(response_model, fallback_value) + + async def batch_structured_response( + self, + llm_model_id: str, + requests: List[Dict[str, Any]], + response_model: Type[T], + max_concurrent: int = 5 + ) -> List[T]: + """ + 批量处理结构化响应 + + Args: + llm_model_id: LLM模型ID + requests: 请求列表,每个请求包含system_prompt等参数 + response_model: 响应模型类 + max_concurrent: 最大并发数 + + Returns: + 结构化响应列表 + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_single_request(request: Dict[str, Any]) -> T: + async with semaphore: + return await self.structured_response( + llm_model_id=llm_model_id, + system_prompt=request.get('system_prompt', ''), + response_model=response_model, + user_message=request.get('user_message'), + fallback_value=request.get('fallback_value') + ) + + tasks = [process_single_request(req) for req in requests] + return await asyncio.gather(*tasks) + + async def simple_response( + self, + llm_model_id: str, + system_prompt: str, + user_message: Optional[str] = None, + fallback_message: str = "信息不足,无法回答" + ) -> str: + """ + 简单的文本响应接口 + + Args: + llm_model_id: LLM模型ID + system_prompt: 系统提示词 + user_message: 用户消息(可选) + fallback_message: 失败时的降级消息 + + Returns: + 响应文本 + """ + try: + llm_client = self._get_cached_client(llm_model_id) + + messages = [{"role": "system", "content": system_prompt}] + if user_message: + messages.append({"role": "user", "content": user_message}) + + response = await llm_client.response(messages=messages) + + if not response or not response.strip(): + return fallback_message + + return response.strip() + + except Exception as e: + logger.error(f"简单响应失败: {e}", exc_info=True) + return fallback_message + + def _create_fallback_response(self, response_model: Type[T], fallback_value: Optional[Any]) -> T: + """创建降级响应""" + try: + if fallback_value is not None: + if isinstance(fallback_value, response_model): + return fallback_value + elif isinstance(fallback_value, dict): + return response_model(**fallback_value) + + # 尝试创建空的响应模型 + if hasattr(response_model, 'root'): + # RootModel类型 + return response_model([]) + else: + # 普通BaseModel类型 + return response_model() + + except Exception as e: + logger.error(f"创建降级响应失败: {e}") + # 最后的降级策略 + if hasattr(response_model, 'root'): + return response_model([]) + else: + return response_model() + + def clear_cache(self): + """清理客户端缓存""" + self._client_cache.clear() + + +class LLMServiceMixin: + """ + LLM服务混入类,为节点提供便捷的LLM调用方法 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._llm_service: Optional[OptimizedLLMService] = None + + def get_llm_service(self, db_session: Session) -> OptimizedLLMService: + """获取LLM服务实例""" + if self._llm_service is None: + self._llm_service = OptimizedLLMService(db_session) + return self._llm_service + + async def call_llm_structured( + self, + state: Dict[str, Any], + db_session: Session, + system_prompt: str, + response_model: Type[T], + user_message: Optional[str] = None, + fallback_value: Optional[Any] = None + ) -> T: + """ + 便捷的结构化LLM调用方法 + + Args: + state: 状态字典,包含memory_config + db_session: 数据库会话 + system_prompt: 系统提示词 + response_model: 响应模型类 + user_message: 用户消息(可选) + fallback_value: 失败时的降级值 + + Returns: + 结构化响应对象 + """ + memory_config = state.get('memory_config') + if not memory_config: + raise ValueError("State中缺少memory_config") + + llm_model_id = memory_config.llm_model_id + if not llm_model_id: + raise ValueError("Memory config中缺少llm_model_id") + + llm_service = self.get_llm_service(db_session) + return await llm_service.structured_response( + llm_model_id=llm_model_id, + system_prompt=system_prompt, + response_model=response_model, + user_message=user_message, + fallback_value=fallback_value + ) + + async def call_llm_simple( + self, + state: Dict[str, Any], + db_session: Session, + system_prompt: str, + user_message: Optional[str] = None, + fallback_message: str = "信息不足,无法回答" + ) -> str: + """ + 便捷的简单LLM调用方法 + + Args: + state: 状态字典,包含memory_config + db_session: 数据库会话 + system_prompt: 系统提示词 + user_message: 用户消息(可选) + fallback_message: 失败时的降级消息 + + Returns: + 响应文本 + """ + memory_config = state.get('memory_config') + if not memory_config: + raise ValueError("State中缺少memory_config") + + llm_model_id = memory_config.llm_model_id + if not llm_model_id: + raise ValueError("Memory config中缺少llm_model_id") + + llm_service = self.get_llm_service(db_session) + return await llm_service.simple_response( + llm_model_id=llm_model_id, + system_prompt=system_prompt, + user_message=user_message, + fallback_message=fallback_message + ) \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/services/parameter_builder.py b/api/app/core/memory/agent/services/parameter_builder.py similarity index 87% rename from api/app/core/memory/agent/mcp_server/services/parameter_builder.py rename to api/app/core/memory/agent/services/parameter_builder.py index d5305dc6..a58fcf1a 100644 --- a/api/app/core/memory/agent/mcp_server/services/parameter_builder.py +++ b/api/app/core/memory/agent/services/parameter_builder.py @@ -4,22 +4,19 @@ Parameter Builder for constructing tool call arguments. This service provides tool-specific parameter transformation logic to build correct arguments for each tool type. """ - from typing import Any, Dict, Optional - from app.core.logging_config import get_agent_logger -from app.schemas.memory_config_schema import MemoryConfig logger = get_agent_logger(__name__) class ParameterBuilder: """Service for building tool call arguments based on tool type.""" - + def __init__(self): """Initialize the parameter builder.""" logger.info("ParameterBuilder initialized") - + def build_tool_args( self, tool_name: str, @@ -28,9 +25,8 @@ class ParameterBuilder: search_switch: str, apply_id: str, group_id: str, - memory_config: MemoryConfig, storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, + user_rag_memory_id: Optional[str] = None ) -> Dict[str, Any]: """ Build tool arguments based on tool type. @@ -49,7 +45,6 @@ class ParameterBuilder: search_switch: Search routing parameter apply_id: Application identifier group_id: Group identifier - memory_config: MemoryConfig object containing all configuration storage_type: Storage type for the workspace (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) @@ -60,19 +55,18 @@ class ParameterBuilder: base_args = { "usermessages": tool_call_id, "apply_id": apply_id, - "group_id": group_id, - "memory_config": memory_config, + "group_id": group_id } - + # Always add storage_type and user_rag_memory_id (with defaults if None) base_args["storage_type"] = storage_type if storage_type is not None else "" base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else "" # Tool-specific argument construction - if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]: - # These tools expect dict context + if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']: + # Verify expects dict context return { - "context": content if isinstance(content, dict) else {"content": content}, + "context": content if isinstance(content, dict) else {}, **base_args } diff --git a/api/app/core/memory/agent/mcp_server/services/search_service.py b/api/app/core/memory/agent/services/search_service.py similarity index 75% rename from api/app/core/memory/agent/mcp_server/services/search_service.py rename to api/app/core/memory/agent/services/search_service.py index 47295f87..8a2e7cfe 100644 --- a/api/app/core/memory/agent/mcp_server/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -4,31 +4,21 @@ Search Service for executing hybrid search and processing results. This service provides clean search result processing with content extraction and deduplication. """ - -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Tuple, Optional from app.core.logging_config import get_agent_logger from app.core.memory.src.search import run_hybrid_search from app.core.memory.utils.data.text_utils import escape_lucene_query -if TYPE_CHECKING: - from app.schemas.memory_config_schema import MemoryConfig logger = get_agent_logger(__name__) class SearchService: """Service for executing hybrid search and processing results.""" - - def __init__(self, memory_config: "MemoryConfig" = None): - """ - Initialize the search service. - - Args: - memory_config: Optional MemoryConfig for embedding model configuration. - If not provided, must be passed to execute_hybrid_search. - """ - self.memory_config = memory_config + + def __init__(self): + """Initialize the search service.""" logger.info("SearchService initialized") def extract_content_from_result(self, result: dict) -> str: @@ -103,49 +93,40 @@ class SearchService: self, group_id: str, question: str, - limit: int = 15, + limit: int = 5, search_type: str = "hybrid", include: Optional[List[str]] = None, - rerank_alpha: float = 0.6, - activation_boost_factor: float = 0.8, + rerank_alpha: float = 0.4, output_path: str = "search_results.json", return_raw_results: bool = False, - memory_config: "MemoryConfig" = None, + memory_config = None ) -> Tuple[str, str, Optional[dict]]: """ - Execute hybrid search with two-stage ranking. - - Stage 1: Filter by content relevance (BM25 + Embedding) - Stage 2: Rerank by activation values (ACTR) + Execute hybrid search and return clean content. Args: - group_id: Group identifier for filtering + group_id: Group identifier for filtering results question: Search query text - limit: Max results per category (default: 15) - search_type: "hybrid", "keyword", or "embedding" (default: "hybrid") - include: Result types (default: ["statements", "chunks", "entities", "summaries"]) - rerank_alpha: BM25 weight (default: 0.6) - activation_boost_factor: Activation impact on memory strength (default: 0.8) - output_path: JSON output path (default: "search_results.json") - return_raw_results: Return full metadata (default: False) - memory_config: MemoryConfig for embedding model + limit: Maximum number of results to return (default: 5) + search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") + include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"]) + rerank_alpha: Weight for BM25 scores in reranking (default: 0.4) + output_path: Path to save search results (default: "search_results.json") + return_raw_results: If True, also return the raw search results as third element (default: False) + memory_config: Memory configuration object (required) Returns: - Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results) + Tuple of (clean_content, cleaned_query, raw_results) + raw_results is None if return_raw_results=False """ if include is None: include = ["statements", "chunks", "entities", "summaries"] - - # Use provided memory_config or fall back to instance config - config = memory_config or self.memory_config - if not config: - raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search") - + # Clean query cleaned_query = self.clean_query(question) - + try: - # Execute search using memory_config + # Execute search answer = await run_hybrid_search( query_text=cleaned_query, search_type=search_type, @@ -153,9 +134,8 @@ class SearchService: limit=limit, include=include, output_path=output_path, - memory_config=config, - rerank_alpha=rerank_alpha, - activation_boost_factor=activation_boost_factor, + memory_config=memory_config, + rerank_alpha=rerank_alpha ) # Extract results based on search type and include parameter diff --git a/api/app/core/memory/agent/mcp_server/services/session_service.py b/api/app/core/memory/agent/services/session_service.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/services/session_service.py rename to api/app/core/memory/agent/services/session_service.py diff --git a/api/app/core/memory/agent/mcp_server/services/template_service.py b/api/app/core/memory/agent/services/template_service.py similarity index 94% rename from api/app/core/memory/agent/mcp_server/services/template_service.py rename to api/app/core/memory/agent/services/template_service.py index 95223f0b..1bf86375 100644 --- a/api/app/core/memory/agent/mcp_server/services/template_service.py +++ b/api/app/core/memory/agent/services/template_service.py @@ -3,12 +3,22 @@ Template Service for loading and rendering Jinja2 templates. This service provides centralized template management with caching and error handling. """ + import os from functools import lru_cache -from typing import Optional -from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound -from app.core.logging_config import get_agent_logger, log_prompt_rendering +from jinja2 import ( + Environment, + FileSystemLoader, + Template, + TemplateNotFound, +) + +from app.core.logging_config import ( + get_agent_logger, + log_prompt_rendering, +) + logger = get_agent_logger(__name__) diff --git a/api/app/core/memory/agent/utils/__init__.py b/api/app/core/memory/agent/utils/__init__.py deleted file mode 100644 index 2b77e240..00000000 --- a/api/app/core/memory/agent/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Agent utilities.""" - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor - -__all__ = [ - "MultimodalProcessor", -] diff --git a/api/app/core/memory/agent/utils/llm_client_pool.py b/api/app/core/memory/agent/utils/llm_client_pool.py new file mode 100644 index 00000000..fddd54f6 --- /dev/null +++ b/api/app/core/memory/agent/utils/llm_client_pool.py @@ -0,0 +1,56 @@ + +import asyncio +from typing import Dict, Optional +from app.core.memory.utils.llm.llm_utils import get_llm_client_fast +from app.db import get_db +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) + +class LLMClientPool: + """LLM客户端连接池""" + + def __init__(self, max_size: int = 5): + self.max_size = max_size + self.pools: Dict[str, asyncio.Queue] = {} + self.active_clients: Dict[str, int] = {} + + async def get_client(self, llm_model_id: str): + """获取LLM客户端""" + if llm_model_id not in self.pools: + self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size) + self.active_clients[llm_model_id] = 0 + + pool = self.pools[llm_model_id] + + try: + # 尝试从池中获取客户端 + client = pool.get_nowait() + logger.debug(f"从池中获取LLM客户端: {llm_model_id}") + return client + except asyncio.QueueEmpty: + # 池为空,创建新客户端 + if self.active_clients[llm_model_id] < self.max_size: + db_session = next(get_db()) + client = get_llm_client_fast(llm_model_id, db_session) + self.active_clients[llm_model_id] += 1 + logger.debug(f"创建新LLM客户端: {llm_model_id}") + return client + else: + # 等待可用客户端 + logger.debug(f"等待LLM客户端可用: {llm_model_id}") + return await pool.get() + + async def return_client(self, llm_model_id: str, client): + """归还LLM客户端到池中""" + if llm_model_id in self.pools: + try: + self.pools[llm_model_id].put_nowait(client) + logger.debug(f"归还LLM客户端到池: {llm_model_id}") + except asyncio.QueueFull: + # 池已满,丢弃客户端 + self.active_clients[llm_model_id] -= 1 + logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}") + +# 全局客户端池 +llm_client_pool = LLMClientPool() diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index ec22b628..8dd2f1d3 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -1,40 +1,12 @@ -import asyncio -import json -import logging import os from collections import defaultdict from typing import Annotated, TypedDict -from app.core.memory.agent.utils.messages_tool import read_template_file -from app.core.memory.utils.config.config_utils import ( - get_picture_config, - get_voice_config, -) - -# Removed global variable imports - use dependency injection instead -from dotenv import load_dotenv from langchain_core.messages import AnyMessage from langgraph.graph import add_messages -from openai import OpenAI PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -logger = logging.getLogger(__name__) -load_dotenv() - - -async def picture_model_requests(image_url): - ''' - - Args: - image_url: - Returns: - - ''' - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 ' - system_prompt = await read_template_file(file_path) - result = await Picture_recognize(image_url,system_prompt) - return (result) class WriteState(TypedDict): ''' Langgrapg Writing TypedDict @@ -44,39 +16,69 @@ class WriteState(TypedDict): apply_id:str group_id:str errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] + memory_config: object + write_result: dict + data:str class ReadState(TypedDict): - ''' - Langgrapg READING TypedDict - name: - id:user id - loop_count:Traverse times - search_switch:type - config_id: configuration id for filtering results - errors: list of errors that occurred during workflow execution - ''' - messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息 - name: str - id: str - loop_count:int + """ + LangGraph 工作流状态定义 + + Attributes: + messages: 消息列表,支持自动追加 + loop_count: 遍历次数 + search_switch: 搜索类型开关 + group_id: 组标识 + config_id: 配置ID,用于过滤结果 + data: 从content_input_node传递的内容数据 + spit_data: 从Split_The_Problem传递的分解结果 + tool_calls: 工具调用请求列表 + tool_results: 工具执行结果列表 + memory_config: 内存配置对象 + """ + messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 + loop_count: int search_switch: str - user_id: str - apply_id: str group_id: str config_id: str - errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] - - + data: str # 新增字段用于传递内容 + spit_data: dict # 新增字段用于传递问题分解结果 + problem_extension:dict + storage_type: str + user_rag_memory_id: str + llm_id: str + embedding_id: str + memory_config: object # 新增字段用于传递内存配置对象 + retrieve:dict + RetrieveSummary: dict + InputSummary: dict + verify: dict + SummaryFails: dict + summary: dict class COUNTState: - ''' - The number of times the workflow dialogue retrieval content has no correct message recall traversal - ''' + """ + 工作流对话检索内容计数器 + + 用于记录工作流对话检索内容没有正确消息召回遍历的次数。 + """ + def __init__(self, limit: int = 5): + """ + 初始化计数器 + + Args: + limit: 最大计数限制,默认为5 + """ self.total: int = 0 # 当前累加值 self.limit: int = limit # 最大上限 - def add(self, value: int = 1): - """累加数字,如果达到上限就保持最大值""" + def add(self, value: int = 1) -> None: + """ + 累加数字,如果达到上限就保持最大值 + + Args: + value: 要累加的值,默认为1 + """ self.total += value print(f"[COUNTState] 当前值: {self.total}") if self.total >= self.limit: @@ -84,21 +86,19 @@ class COUNTState: self.total = self.limit # 达到上限不再增加 def get_total(self) -> int: - """获取当前累加值""" + """ + 获取当前累加值 + + Returns: + 当前累加值 + """ return self.total - def reset(self): + def reset(self) -> None: """手动重置累加值""" self.total = 0 print("[COUNTState] 已重置为 0") - -def merge_to_key_value_pairs(data, query_key, result_key): - grouped = defaultdict(list) - for item in data: - grouped[item[query_key]].append(item[result_key]) - return [{key: values} for key, values in grouped.items()] - def deduplicate_entries(entries): seen = set() deduped = [] @@ -109,70 +109,37 @@ def deduplicate_entries(entries): deduped.append(entry) return deduped +def merge_to_key_value_pairs(data, query_key, result_key): + grouped = defaultdict(list) + for item in data: + grouped[item[query_key]].append(item[result_key]) + return [{key: values} for key, values in grouped.items()] -async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str: +def convert_extended_question_to_question(data): """ - Updated to eliminate global variables in favor of explicit parameters. - + 递归地将数据中的 extended_question 字段转换为 question 字段 + Args: - image_path: Path to image file - PROMPT_TICKET_EXTRACTION: Extraction prompt - picture_model_name: Picture model name (required, no longer from global variables) + data: 要转换的数据(可能是字典、列表或其他类型) + + Returns: + 转换后的数据 """ - try: - model_config = get_picture_config(picture_model_name) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base=model_config['api_base'] - - logger.info(f"model_name: {backend_model_name}") - logger.info(f"api_key set: {'yes' if api_key else 'no'}") - logger.info(f"base_url: {model_config['api_base']}") - - client = OpenAI( - api_key=api_key, base_url=api_base, - ) - completion = client.chat.completions.create( - model=backend_model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url":image_path, - }, - {"type": "text", - "text": PROMPT_TICKET_EXTRACTION} - ] - } - ]) - picture_text = completion.choices[0].message.content - picture_text = picture_text.replace('```json', '').replace('```', '') - picture_text = json.loads(picture_text) - return (picture_text['statement']) - -async def Voice_recognize(voice_model_name: str): - """ - Updated to eliminate global variables in favor of explicit parameters. - - Args: - voice_model_name: Voice model name (required, no longer from global variables) - """ - try: - model_config = get_voice_config(voice_model_name) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base = model_config['api_base'] - return api_key,backend_model_name,api_base - - + if isinstance(data, dict): + # 创建新字典来存储转换后的数据 + converted = {} + for key, value in data.items(): + if key == 'extended_question': + # 将 extended_question 转换为 question + converted['question'] = convert_extended_question_to_question(value) + else: + # 递归处理其他字段 + converted[key] = convert_extended_question_to_question(value) + return converted + elif isinstance(data, list): + # 递归处理列表中的每个元素 + return [convert_extended_question_to_question(item) for item in data] + else: + # 其他类型直接返回 + return data \ No newline at end of file diff --git a/api/app/core/memory/agent/utils/mcp_tools.py b/api/app/core/memory/agent/utils/mcp_tools.py deleted file mode 100644 index 7ede9843..00000000 --- a/api/app/core/memory/agent/utils/mcp_tools.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -from app.core.config import settings - -def get_mcp_server_config(): - """ - Get the MCP server configuration. - - Uses MCP_SERVER_URL environment variable if set (for Docker), - otherwise falls back to SERVER_IP and MCP_PORT (for local development). - """ - # Get MCP port from environment (default: 8081) - mcp_port = os.getenv("MCP_PORT", "8081") - - # In Docker: MCP_SERVER_URL=http://mcp-server:8081 - # In local dev: uses SERVER_IP (127.0.0.1 or localhost) - mcp_server_url = os.getenv("MCP_SERVER_URL") - - if mcp_server_url: - # Docker environment: use full URL from environment - base_url = mcp_server_url - else: - # Local development: build URL from SERVER_IP and MCP_PORT - base_url = f"http://{settings.SERVER_IP}:{mcp_port}" - - mcp_server_config = { - "data_flow": { - "url": f"{base_url}/sse", - "transport": "sse", - "timeout": 15000, - "sse_read_timeout": 15000, - } - } - return mcp_server_config diff --git a/api/app/core/memory/agent/utils/messages_tool.py b/api/app/core/memory/agent/utils/messages_tool.py deleted file mode 100644 index 769e795a..00000000 --- a/api/app/core/memory/agent/utils/messages_tool.py +++ /dev/null @@ -1,260 +0,0 @@ -import json -import logging -import re -from typing import Any, List - -from app.core.logging_config import get_agent_logger -from langchain_core.messages import AnyMessage - -logger = get_agent_logger(__name__) - - -def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]: - out = [] - for m in msgs: - if hasattr(m, "content"): - out.append({"role": "user", "content": getattr(m, "content", "")}) - elif isinstance(m, dict) and "role" in m and "content" in m: - out.append(m) - else: - out.append({"role": "user", "content": str(m)}) - return out - - -def _extract_content(resp: Any) -> str: - """Extract LLM content and sanitize to raw JSON/text. - - - Supports both object and dict response shapes. - - Removes leading role labels (e.g., "Assistant:"). - - Strips Markdown code fences like ```json ... ```. - - Attempts to isolate the first valid JSON array/object block when extra text is present. - """ - - def _to_text(r: Any) -> str: - try: - # 对象形式: resp.choices[0].message.content - if hasattr(r, "choices") and getattr(r, "choices", None): - msg = r.choices[0].message - if hasattr(msg, "content"): - return msg.content - if isinstance(msg, dict) and "content" in msg: - return msg["content"] - # 字典形式: resp["choices"][0]["message"]["content"] - if isinstance(r, dict): - return r.get("choices", [{}])[0].get("message", {}).get("content", "") - except Exception: - pass - return str(r) - - def _clean_text(text: str) -> str: - s = str(text).strip() - # 移除可能的角色前缀 - s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s) - # 提取 ```json ... ``` 代码块 - m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I) - if m: - s = m.group(1).strip() - # 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段 - if not (s.startswith("{") or s.startswith("[")): - left = s.find("[") - right = s.rfind("]") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - else: - left = s.find("{") - right = s.rfind("}") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - return s - - raw = _to_text(resp) - return _clean_text(raw) - -def Resolve_username(usermessages): - ''' - Extract username - Args: - usermessages: user name - - Returns: - - ''' - usermessages = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages[:-1]) - return sessionid - - -# TODO: USE app.core.memory.src.utils.render_template instead -async def read_template_file(template_path: str) -> str: - """ - 读取模板文件 - - Args: - template_path: 模板文件路径 - - Returns: - 模板内容字符串 - - Note: - 建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能 - """ - try: - with open(template_path, "r", encoding="utf-8") as f: - return f.read() - except FileNotFoundError: - logger.error(f"模板文件未找到: {template_path}") - raise - except IOError as e: - logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True) - raise - - -async def Problem_Extension_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - extent_quest = [] - original = context.get('original', '') - messages = context.get('context', '') - - # Handle empty or non-string messages - if not messages: - return extent_quest, original - - if isinstance(messages, str): - try: - messages = json.loads(messages) - except json.JSONDecodeError: - # If JSON parsing fails, return empty list - return extent_quest, original - - if isinstance(messages, list): - for message in messages: - question = message.get('question', '') - type = message.get('type', '') - extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"}) - - return extent_quest, original - - -async def Retriev_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}") - - if isinstance(context, dict): - logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}") - if 'context' in context or 'original' in context: - content = context.get('context', {}) - original = context.get('original', '') - logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'") - return content, original - - # Return empty defaults if context is not a dict or doesn't have expected keys - logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults") - return {}, '' - -async def Verify_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - - query = context['context']['Query'] - Query_small_list = context['context']['Expansion_issue'] - Result_small = [] - Query_small = [] - for i in Query_small_list: - Result_small.append(i['Answer_Small'][0]) - Query_small.append(i['Query_small']) - return Query_small, Result_small, query - - -async def Summary_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - query = re.findall(r'"query": (.*?),', messages)[0] - query = query.replace('[', '').replace(']', '').strip() - matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages) - answer_small_texts = [] - for m in matches: - try: - parsed = json.loads(m) - for item in parsed: - answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', '')) - except Exception: - answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', '')) - - return answer_small_texts, query - - -async def VerifyTool_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - content_messages = messages.split('"context":')[1].replace('""', '"') - messages = str(content_messages).split("name='Retrieve'")[0] - query = re.findall('"Query": "(.*?)"', messages)[0] - Query_small = re.findall('"Query_small": "(.*?)"', messages) - Result_small = re.findall('"Result_small": "(.*?)"', messages) - return Query_small, Result_small, query - - -async def Retrieve_Summary_messages_deal(context): - pass - - -async def Retrieve_verify_tool_messages_deal(context, history, query): - ''' - Extract data - Args: - context: - Returns: - ''' - results = [] - # 统一转为字符串,避免 None 或非字符串导致正则报错 - text = str(context) - blocks = re.findall(r'\{(.*?)\}', text, flags=re.S) - for block in blocks: - query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block) - answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block) - status = re.search(r'"status"\s*:\s*"([^"]*)"', block) - query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block) - - results.append({ - "query_small": query_small.group(1) if query_small else None, - "answer_small": answer_small.group(1) if answer_small else None, - # 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误 - "status": status.group(1) if status else "", - "query_answer": query_answer.group(1) if query_answer else None - }) - result = [] - for r in results: - # 统一按字符串判定状态,兼容大小写和缺失情况 - status_str = str(r.get('status', '')).strip().lower() - if status_str == 'false': - continue - else: - result.append(r) - split_result = 'failed' if not result else 'success' - result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "", - "history": history} - return result diff --git a/api/app/core/memory/agent/utils/messages_tools.py b/api/app/core/memory/agent/utils/messages_tools.py new file mode 100644 index 00000000..db95319f --- /dev/null +++ b/api/app/core/memory/agent/utils/messages_tools.py @@ -0,0 +1,194 @@ +from typing import List, Dict, Any +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) +async def read_template_file(template_path: str) -> str: + """ + 读取模板文件 + + Args: + template_path: 模板文件路径 + + Returns: + 模板内容字符串 + + Note: + 建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能 + """ + try: + with open(template_path, "r", encoding="utf-8") as f: + return f.read() + except FileNotFoundError: + logger.error(f"模板文件未找到: {template_path}") + raise + except IOError as e: + logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True) + raise + +def reorder_output_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 重新排序输出结果,将 retrieval_summary 类型的数据放到最后面 + + Args: + results: 原始输出结果列表 + + Returns: + 重新排序后的结果列表 + """ + retrieval_summaries = [] + other_results = [] + + # 分离 retrieval_summary 和其他类型的结果 + for result in results: + if 'summary' in result.get('type'): + retrieval_summaries.append(result) + else: + other_results.append(result) + + # 将 retrieval_summary 放到最后 + return other_results + retrieval_summaries + +def optimize_search_results(intermediate_outputs): + """ + 优化检索结果,合并多个搜索结果,过滤空结果,统一格式 + + Args: + intermediate_outputs: 原始的中间输出列表 + + Returns: + 优化后的检索结果列表 + """ + optimized_results = [] + + for item in intermediate_outputs: + if not item or item == [] or item == {}: + continue + + # 检查是否是搜索结果类型 + if isinstance(item, dict) and item.get('type') == 'search_result': + raw_results = item.get('raw_results', {}) + + # 如果 raw_results 为空,跳过 + if not raw_results or raw_results == [] or raw_results == {}: + continue + + # 创建优化后的结果结构 + optimized_item = { + "type": "search_result", + "title": f"检索结果 ({item.get('index', 1)}/{item.get('total', 1)})", + "query": item.get('query', ''), + "raw_results": {}, + "index": item.get('index', 1), + "total": item.get('total', 1) + } + + # 合并所有搜索结果类型到一个 raw_results 中 + merged_raw_results = {} + + # 处理 time_search + if 'time_search' in raw_results and raw_results['time_search']: + merged_raw_results['time_search'] = raw_results['time_search'] + + # 处理 keyword_search + if 'keyword_search' in raw_results and raw_results['keyword_search']: + merged_raw_results['keyword_search'] = raw_results['keyword_search'] + + # 处理 embedding_search + if 'embedding_search' in raw_results and raw_results['embedding_search']: + merged_raw_results['embedding_search'] = raw_results['embedding_search'] + + # 处理 combined_summary + if 'combined_summary' in raw_results and raw_results['combined_summary']: + merged_raw_results['combined_summary'] = raw_results['combined_summary'] + + # 处理 reranked_results + if 'reranked_results' in raw_results and raw_results['reranked_results']: + merged_raw_results['reranked_results'] = raw_results['reranked_results'] + + # 如果合并后的结果不为空,添加到优化结果中 + if merged_raw_results: + optimized_item['raw_results'] = merged_raw_results + optimized_results.append(optimized_item) + else: + # 非搜索结果类型,直接添加 + optimized_results.append(item) + + return optimized_results + + +def merge_multiple_search_results(intermediate_outputs): + """ + 将多个搜索结果合并为一个统一的搜索结果 + + Args: + intermediate_outputs: 原始的中间输出列表 + + Returns: + 合并后的结果列表 + """ + search_results = [] + other_results = [] + + # 分离搜索结果和其他结果 + for item in intermediate_outputs: + if isinstance(item, dict) and item.get('type') == 'search_result': + raw_results = item.get('raw_results', {}) + # 只保留有内容的搜索结果 + if raw_results and raw_results != [] and raw_results != {}: + search_results.append(item) + else: + other_results.append(item) + + # 如果没有搜索结果,返回原始结果 + if not search_results: + return intermediate_outputs + + # 如果只有一个搜索结果,优化格式后返回 + if len(search_results) == 1: + optimized = optimize_search_results(search_results) + return other_results + optimized + + # 合并多个搜索结果 + merged_raw_results = {} + all_queries = [] + + for result in search_results: + query = result.get('query', '') + if query: + all_queries.append(query) + + raw_results = result.get('raw_results', {}) + + # 合并各种搜索类型的结果 + for search_type in ['time_search', 'keyword_search', 'embedding_search', 'combined_summary', + 'reranked_results']: + if search_type in raw_results and raw_results[search_type]: + if search_type not in merged_raw_results: + merged_raw_results[search_type] = raw_results[search_type] + else: + # 如果是字典类型,需要合并 + if isinstance(raw_results[search_type], dict) and isinstance(merged_raw_results[search_type], dict): + for key, value in raw_results[search_type].items(): + if key not in merged_raw_results[search_type]: + merged_raw_results[search_type][key] = value + elif isinstance(value, list) and isinstance(merged_raw_results[search_type][key], list): + merged_raw_results[search_type][key].extend(value) + elif isinstance(raw_results[search_type], list): + if isinstance(merged_raw_results[search_type], list): + merged_raw_results[search_type].extend(raw_results[search_type]) + else: + merged_raw_results[search_type] = raw_results[search_type] + + # 创建合并后的结果 + if merged_raw_results: + merged_result = { + "type": "search_result", + "title": f"合并检索结果 (共{len(search_results)}个查询)", + "query": " | ".join(all_queries), + "raw_results": merged_raw_results, + "index": 1, + "total": 1 + } + return other_results + [merged_result] + + return other_results diff --git a/api/app/core/memory/agent/utils/model_tool.py b/api/app/core/memory/agent/utils/model_tool.py deleted file mode 100644 index 969a2a91..00000000 --- a/api/app/core/memory/agent/utils/model_tool.py +++ /dev/null @@ -1,38 +0,0 @@ - - -# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -# sys.path.insert(0, project_root) - -# load_dotenv() - -# async def llm_client_chat(messages: List[dict]) -> str: -# """使用 OpenAI 兼容接口进行对话,返回内容字符串。""" -# try: -# cfg = get_model_config(SELECTED_LLM_ID) -# rb_config = RedBearModelConfig( -# model_name=cfg["model_name"], -# provider=cfg["provider"], -# api_key=cfg["api_key"], -# base_url=cfg["base_url"], -# ) -# client = OpenAIClient(model_config=rb_config, type_="chat") - -# except Exception as e: -# logger.error(f"获取模型配置失败:{e}") -# err = f"获取模型配置失败:{str(e)}。请检查!!!" -# return err -# try: -# response = await client.chat(messages) -# print(f"model_tool's llm_client_chat response ======>:\n {response}") -# return _extract_content(response) -# # return _extract_content(result) -# except Exception as e: -# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。") -# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。" - -# async def main(image_url): -# await llm_client_chat(image_url) -# -# # 运行主函数 -# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav'])) -# diff --git a/api/app/core/memory/agent/utils/multimodal.py b/api/app/core/memory/agent/utils/multimodal.py deleted file mode 100644 index 0fc52634..00000000 --- a/api/app/core/memory/agent/utils/multimodal.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Multimodal input processor for handling image and audio content. - -This module provides utilities for detecting and processing multimodal inputs -(images and audio files) by converting them to text using appropriate models. -""" - -import logging -from typing import List - -from app.core.memory.agent.multimodal.speech_model import Vico_recognition -from app.core.memory.agent.utils.llm_tools import picture_model_requests - -logger = logging.getLogger(__name__) - - -class MultimodalProcessor: - """ - Processor for handling multimodal inputs (images and audio). - - This class detects image and audio file paths in input content and converts - them to text using appropriate recognition models. - """ - - # Supported file extensions - IMAGE_EXTENSIONS = ['.jpg', '.png'] - AUDIO_EXTENSIONS = [ - 'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov', - 'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv' - ] - - def __init__(self): - """Initialize the multimodal processor.""" - pass - - def is_image(self, content: str) -> bool: - """ - Check if content is an image file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported image extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_image("photo.jpg") - True - >>> processor.is_image("document.pdf") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS) - - def is_audio(self, content: str) -> bool: - """ - Check if content is an audio file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported audio extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_audio("recording.mp3") - True - >>> processor.is_audio("video.mp4") - True - >>> processor.is_audio("document.txt") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS) - - async def process_input(self, content: str) -> str: - """ - Process input content, converting images/audio to text if needed. - - This method detects if the input is an image or audio file and converts - it to text using the appropriate recognition model. If processing fails - or the content is not multimodal, it returns the original content. - - Args: - content: Input string (may be file path or regular text) - - Returns: - Text content (original or converted from image/audio) - - Examples: - >>> processor = MultimodalProcessor() - >>> await processor.process_input("photo.jpg") - "Recognized text from image..." - - >>> await processor.process_input("Hello world") - "Hello world" - """ - if not isinstance(content, str): - logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}") - return str(content) - - try: - # Check for image input - if self.is_image(content): - logger.info(f"[MultimodalProcessor] Detected image input: {content}") - result = await picture_model_requests(content) - logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...") - return result - - # Check for audio input - if self.is_audio(content): - logger.info(f"[MultimodalProcessor] Detected audio input: {content}") - result = await Vico_recognition([content]).run() - logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...") - return result - - except Exception as e: - logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True) - logger.info("[MultimodalProcessor] Falling back to original content") - return content - - # Return original content if not multimodal - return content diff --git a/api/app/core/memory/agent/utils/performance_monitor.py b/api/app/core/memory/agent/utils/performance_monitor.py new file mode 100644 index 00000000..d2d9fdfa --- /dev/null +++ b/api/app/core/memory/agent/utils/performance_monitor.py @@ -0,0 +1,56 @@ + +import time +import json +from collections import defaultdict +from typing import Dict, List +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) + +class ProblemExtensionMonitor: + """Problem_Extension性能监控器""" + + def __init__(self): + self.metrics = defaultdict(list) + self.slow_queries = [] + self.error_count = 0 + + def record_execution(self, duration: float, question_count: int, success: bool): + """记录执行指标""" + self.metrics['durations'].append(duration) + self.metrics['question_counts'].append(question_count) + + if not success: + self.error_count += 1 + + # 记录慢查询(超过10秒) + if duration > 10.0: + self.slow_queries.append({ + 'duration': duration, + 'question_count': question_count, + 'timestamp': time.time() + }) + + def get_stats(self) -> Dict: + """获取统计信息""" + durations = self.metrics['durations'] + if not durations: + return {"message": "暂无数据"} + + return { + "total_executions": len(durations), + "avg_duration": sum(durations) / len(durations), + "max_duration": max(durations), + "min_duration": min(durations), + "slow_queries_count": len(self.slow_queries), + "error_rate": self.error_count / len(durations) if durations else 0, + "recent_slow_queries": self.slow_queries[-5:] # 最近5个慢查询 + } + + def log_stats(self): + """记录统计信息到日志""" + stats = self.get_stats() + logger.info(f"Problem_Extension性能统计: {json.dumps(stats, indent=2)}") + +# 全局监控器实例 +performance_monitor = ProblemExtensionMonitor() diff --git a/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 new file mode 100644 index 00000000..a0e21fbd --- /dev/null +++ b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 @@ -0,0 +1,81 @@ + +你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则: + +角色: +- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。 +- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。 +- 如果历史信息或上下文与当前问题无关,可忽略。 + +--- + +### 历史信息参考 +在生成扩展问题时,你可以参考以下历史数据(如果提供): +- 历史对话或任务的主题; +- 历史中出现的关键实体(时间、人物、地点、研究主题等); +- 历史中已解答的问题(避免重复); +- 历史推理链(保持逻辑一致性)。 + +> 如果没有提供历史信息,则仅根据当前输入问题进行分析。 +输入历史信息内容:{{history}} + +## User Input +{% if questions is string %} +{{ questions }} +{% else %} +{% for question in questions %} +- {{ question }} +{% endfor %} +{% endif %} + +需求: +- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。 +- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。 +- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。 +- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。 +- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。 +- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。 +- 子问题数量不超过4个。 +- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。 + 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] + 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? + + + +输出要求: +- 仅输出 JSON 数组,不要包含任何解释或代码块。 +- 每个元素包含: + - `original_question`: 原始问题 + - `extended_question`: 扩展后的问题 + - `type`: 类型(事实检索/澄清/定义/比较/行动建议) + - `reason`: 生成该扩展问题的简短理由 +- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。 + +示例: +输入: +[ + "问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳", +] + +输出: +[ + { + "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", + "extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?", + "type": "多跳", + "reason": "输出原问题的关键要素" + }, + { + "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", + "extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?", + "type": "多跳", + "reason": "输出原问题的关键要素" + } +] +**Output format** +**CRITICAL JSON FORMATTING REQUIREMENTS:** +1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes +2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") +3. Ensure all JSON strings are properly closed and comma-separated +4. Do not include line breaks within JSON string values + +The output language should always be the same as the input language.{{ json_schema }} diff --git a/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 index 1fa71df3..5fbe8574 100644 --- a/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 @@ -1,13 +1,10 @@ # 角色 你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。 - # 任务 根据提供的上下文信息回答用户的问题。 - # 输入信息 - 历史对话:{{history}} - 检索信息:{{retrieve_info}} - ## User Query {{query}} diff --git a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 index f4d4665c..d6ad8cab 100644 --- a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 @@ -9,8 +9,8 @@ 3. 判断Answer_Small和Query_Small之间分析出来的关系状态 4. 如果是True保留,否则不要相对应的问题和回答 5. 输出,需要严格按照模版 -输入:{{history}} -历史消息:{"history":{{sentence}}} +输入:{{sentence}} +历史消息:{"history":{{history}}} ### 第一步 获取用户的输入 获取用户的输入提取对应的Query_Small和Answer_Small ### 第二步 分析验证 diff --git a/api/app/core/memory/agent/utils/session_tools.py b/api/app/core/memory/agent/utils/session_tools.py new file mode 100644 index 00000000..b2d4f0ff --- /dev/null +++ b/api/app/core/memory/agent/utils/session_tools.py @@ -0,0 +1,169 @@ +""" +Session Service for managing user sessions and conversation history. + +This service provides clean Redis interactions with error handling and +session management utilities. +""" +from typing import List, Optional + +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.redis_tool import RedisSessionStore + + +logger = get_agent_logger(__name__) + + +class SessionService: + """Service for managing user sessions and conversation history.""" + + def __init__(self, store: RedisSessionStore): + """ + Initialize the session service. + + Args: + store: Redis session store instance + """ + self.store = store + logger.info("SessionService initialized") + + def resolve_user_id(self, session_string: str) -> str: + """ + Extract user ID from session string. + + Handles formats like: + - 'call_id_user123' -> 'user123' + - 'prefix_id_user456_suffix' -> 'user456_suffix' + + Args: + session_string: Session identifier string + + Returns: + Extracted user ID + """ + try: + # Split by '_id_' and take everything after it + parts = session_string.split('_id_') + if len(parts) > 1: + return parts[1] + + # Fallback: return original string + return session_string + + except Exception as e: + logger.warning( + f"Failed to parse user ID from session string '{session_string}': {e}" + ) + return session_string + + async def get_history( + self, + user_id: str, + apply_id: str, + group_id: str + ) -> List[dict]: + """ + Retrieve conversation history from Redis. + + Args: + user_id: User identifier + apply_id: Application identifier + group_id: Group identifier + + Returns: + List of conversation history items with Query and Answer keys + Returns empty list if no history found or on error + """ + try: + history = self.store.find_user_apply_group(user_id, apply_id, group_id) + + # Validate history structure + if not isinstance(history, list): + logger.warning( + f"Invalid history format for user {user_id}, " + f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" + ) + return [] + + return history + + except Exception as e: + logger.error( + f"Failed to retrieve history for user {user_id}, " + f"apply {apply_id}, group {group_id}: {e}", + exc_info=True + ) + # Return empty list on error to allow execution to continue + return [] + + async def save_session( + self, + user_id: str, + query: str, + apply_id: str, + group_id: str, + ai_response: str + ) -> Optional[str]: + """ + Save conversation turn to Redis. + + Args: + user_id: User identifier + query: User query/message + apply_id: Application identifier + group_id: Group identifier + ai_response: AI response/answer + + Returns: + Session ID if successful, None on error + """ + try: + # Validate required fields + if not user_id: + logger.warning("Cannot save session: user_id is empty") + return None + + if not query: + logger.warning("Cannot save session: query is empty") + return None + + # Save session + session_id = self.store.save_session( + userid=user_id, + messages=query, + apply_id=apply_id, + group_id=group_id, + aimessages=ai_response + ) + + logger.info(f"Session saved successfully: {session_id}") + return session_id + + except Exception as e: + logger.error( + f"Failed to save session for user {user_id}: {e}", + exc_info=True + ) + return None + + async def cleanup_duplicates(self) -> int: + """ + Remove duplicate session entries. + + Duplicates are identified by matching: + - sessionid + - user_id (id field) + - group_id + - messages + - aimessages + + Returns: + Number of duplicate sessions deleted + """ + try: + deleted_count = self.store.delete_duplicate_sessions() + logger.info(f"Cleaned up {deleted_count} duplicate sessions") + return deleted_count + + except Exception as e: + logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True) + return 0 diff --git a/api/app/core/memory/agent/utils/template_tools.py b/api/app/core/memory/agent/utils/template_tools.py new file mode 100644 index 00000000..854c5383 --- /dev/null +++ b/api/app/core/memory/agent/utils/template_tools.py @@ -0,0 +1,117 @@ +""" +Template Service for loading and rendering Jinja2 templates. + +This service provides centralized template management with caching and error handling. +""" +# 标准库 +import os +from functools import lru_cache + +from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound + +from app.core.logging_config import get_agent_logger, log_prompt_rendering + + +logger = get_agent_logger(__name__) + + +class TemplateRenderError(Exception): + """Exception raised when template rendering fails.""" + + def __init__(self, template_name: str, error: Exception, variables: dict): + self.template_name = template_name + self.error = error + self.variables = variables + super().__init__( + f"Failed to render template '{template_name}': {str(error)}" + ) + + +class TemplateService: + """Service for loading and rendering Jinja2 templates with caching.""" + + def __init__(self, template_root: str): + """ + Initialize the template service. + + Args: + template_root: Root directory containing template files + """ + self.template_root = template_root + self.env = Environment( + loader=FileSystemLoader(template_root), + autoescape=False # Disable autoescape for prompt templates + ) + logger.info(f"TemplateService initialized with root: {template_root}") + + @lru_cache(maxsize=128) + def _load_template(self, template_name: str) -> Template: + """ + Load a template from disk with caching. + + Args: + template_name: Relative path to template file + + Returns: + Loaded Jinja2 Template object + + Raises: + TemplateNotFound: If template file doesn't exist + """ + try: + return self.env.get_template(template_name) + except TemplateNotFound as e: + expected_path = os.path.join(self.template_root, template_name) + logger.error( + f"Template not found: {template_name}. " + f"Expected path: {expected_path}" + ) + raise + + async def render_template( + self, + template_name: str, + operation_name: str, + **variables + ) -> str: + """ + Load and render a Jinja2 template. + + Args: + template_name: Relative path to template file + operation_name: Name for logging (e.g., "split_the_problem") + **variables: Template variables to render + + Returns: + Rendered template string + + Raises: + TemplateRenderError: If template loading or rendering fails + """ + try: + # Load template (cached) + template = self._load_template(template_name) + + # Render template + rendered = template.render(**variables) + + # Log rendered prompt + log_prompt_rendering(operation_name, rendered) + + return rendered + + except TemplateNotFound as e: + logger.error( + f"Template rendering failed for {operation_name} " + f"({template_name}): Template not found", + exc_info=True + ) + raise TemplateRenderError(template_name, e, variables) + + except Exception as e: + logger.error( + f"Template rendering failed for {operation_name} " + f"({template_name}): {e}", + exc_info=True + ) + raise TemplateRenderError(template_name, e, variables) diff --git a/api/app/core/memory/agent/utils/type_classifier.py b/api/app/core/memory/agent/utils/type_classifier.py index 3e5358bd..f1df6f04 100644 --- a/api/app/core/memory/agent/utils/type_classifier.py +++ b/api/app/core/memory/agent/utils/type_classifier.py @@ -1,10 +1,9 @@ """ Type classification utility for distinguishing read/write operations. """ -from app.core.config import settings from app.core.logging_config import get_agent_logger, log_prompt_rendering from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.messages_tool import read_template_file +from app.core.memory.agent.utils.messages_tools import read_template_file from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from jinja2 import Template diff --git a/api/app/core/memory/agent/utils/write_to_database.py b/api/app/core/memory/agent/utils/write_to_database.py deleted file mode 100644 index bd78fe9d..00000000 --- a/api/app/core/memory/agent/utils/write_to_database.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import uuid -from datetime import datetime -from typing import Any -from sqlalchemy.orm import Session -import logging -import json - -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo - -logger = logging.getLogger(__name__) - -async def write_to_database(host_id: uuid.UUID, data: Any) -> str: - """ - 将数据写入数据库 - :param host_id: 宿主 ID - :param data: 要写入的数据 - :return: 写入数据库的结果 - """ - # 从数据库会话中获取会话 - db: Session = next(get_db()) - try: - if isinstance(data, (dict, list)): - serialized = json.dumps(data, ensure_ascii=False) - elif isinstance(data, str): - serialized = data - else: - serialized = str(data) - - new_retrieval_info = RetrievalInfo( - # host_id=host_id, - host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"), - retrieve_info=serialized, - created_at=datetime.now() - ) - db.add(new_retrieval_info) - db.commit() - logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}") - return "success to write data to database" - except Exception as e: - db.rollback() - logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}") - raise e - finally: - try: - db.close() - except Exception: - pass diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index f09b35e8..53c941ad 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -7,14 +7,12 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally. import time from datetime import datetime +from dotenv import load_dotenv + from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - ExtractionOrchestrator, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - memory_summary_generation, -) +from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context @@ -23,7 +21,7 @@ from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv + load_dotenv() diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index 2aa286ba..cab6cacd 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -1,48 +1,15 @@ import asyncio -import os -import sys -from typing import List, Tuple - -from neo4j import GraphDatabase -from pydantic import BaseModel, Field - -# ------------------- 自包含路径解析 ------------------- -# 这个代码块确保脚本可以从任何地方运行,并且仍然可以在项目结构中找到它需要的模块。 -try: - # 假设脚本在 /path/to/project/src/analytics/ - # 上升3个级别以到达项目根目录。 - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) - src_path = os.path.join(project_root, 'src') - - # 将 'src' 和 'project_root' 都添加到路径中。 - # 'src' 目录对于像 'from utils.config_utils import ...' 这样的导入是必需的。 - # 'project_root' 目录对于像 'from variate_config import ...' 这样的导入是必需的。 - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -except NameError: - # 为 __file__ 未定义的环境(例如某些交互式解释器)提供回退方案 - project_root = os.path.abspath(os.path.join(os.getcwd())) - src_path = os.path.join(project_root, 'src') - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -# --------------------------------------------------------------------- - -# 现在路径已经配置好,我们可以使用绝对导入 import json +import os +from typing import List, Tuple from app.core.config import settings from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context +from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.services.memory_config_service import MemoryConfigService +from pydantic import BaseModel, Field -#TODO: Fix this -# Default values (previously from definitions.py) -DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123") # 定义用于LLM结构化输出的Pydantic模型 class FilteredTags(BaseModel): @@ -52,34 +19,45 @@ class FilteredTags(BaseModel): async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 + + Args: + tags: 原始标签列表 + group_id: 用户组ID,用于获取配置 + + Returns: + 筛选后的标签列表 + + Raises: + ValueError: 如果无法获取有效的LLM配置 """ try: # Get config_id using get_end_user_connected_config with get_db_context() as db: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + + connected_config = get_end_user_connected_config(group_id, db) + config_id = connected_config.get("memory_config_id") + + if not config_id: + raise ValueError( + f"No memory_config_id found for group_id: {group_id}. " + "Please ensure the user has a valid memory configuration." ) - connected_config = get_end_user_connected_config(group_id, db) - config_id = connected_config.get("memory_config_id") - - if config_id: - # Use the config_id to get the proper LLM client - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config(config_id) - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(memory_config.llm_model_id) - else: - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM if no config found - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(DEFAULT_LLM_ID) - except Exception as e: - print(f"Failed to get user connected config, using default LLM: {e}") - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(DEFAULT_LLM_ID) + + # Use the config_id to get the proper LLM client + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + + if not memory_config.llm_model_id: + raise ValueError( + f"No llm_model_id found in memory config {config_id}. " + "Please configure a valid LLM model." + ) + + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) # 3. 构建Prompt tag_list_str = ", ".join(tags) @@ -107,33 +85,26 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: # 在LLM失败时返回原始标签,确保流程继续 return tags -def get_db_connection(): - """ - 使用项目的标准配置方法建立与Neo4j数据库的连接。 - """ - # 从全局配置获取 Neo4j 连接信息 - uri = settings.NEO4J_URI - user = settings.NEO4J_USERNAME - - # 密码必须为了安全从环境变量加载 - password = os.getenv("NEO4J_PASSWORD") - - if not uri or not user: - raise ValueError("在 config.json 中未找到 Neo4j 的 'uri' 或 'username'。") - if not password: - raise ValueError("NEO4J_PASSWORD 环境变量未设置。") - - # 为此脚本使用同步驱动 - return GraphDatabase.driver(uri, auth=(user, password)) - -def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_raw_tags_from_db( + connector: Neo4jConnector, + group_id: str, + limit: int, + by_user: bool = False +) -> List[Tuple[str, int]]: """ + TODO: not accurate tag extraction 从数据库查询原始的、未经过滤的实体标签及其频率。 + + 使用项目的Neo4jConnector进行查询,遵循仓储模式。 Args: + connector: Neo4j连接器实例 group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 by_user: 是否按user_id查询(默认False,按group_id查询) + + Returns: + List[Tuple[str, int]]: 标签名称和频率的元组列表 """ names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria'] @@ -154,83 +125,55 @@ def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> Li "LIMIT $limit" ) - driver = None - try: - driver = get_db_connection() - with driver.session() as session: - result = session.run(query, id=group_id, limit=limit, names_to_exclude=names_to_exclude) - return [(record["name"], record["frequency"]) for record in result] - finally: - if driver: - driver.close() + # 使用项目的Neo4jConnector执行查询 + results = await connector.execute_query( + query, + id=group_id, + limit=limit, + names_to_exclude=names_to_exclude + ) + + return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 Args: - group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id + group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 by_user: 是否按user_id查询(默认False,按group_id查询) + + Raises: + ValueError: 如果group_id未提供或为空 """ - # 默认从环境变量读取 - group_id = group_id or DEFAULT_GROUP_ID - # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user) - if not raw_tags_with_freq: - return [] - - raw_tag_names = [tag for tag, freq in raw_tags_with_freq] - - # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 - meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) - - # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) - final_tags = [] - for tag, freq in raw_tags_with_freq: - if tag in meaningful_tag_names: - final_tags.append((tag, freq)) - - return final_tags - -if __name__ == "__main__": - print("开始获取热门记忆标签...") + # 验证group_id必须提供且不为空 + if not group_id or not group_id.strip(): + raise ValueError( + "group_id is required. Please provide a valid group_id or user_id." + ) + + # 使用项目的Neo4jConnector + connector = Neo4jConnector() try: - # 直接使用环境变量中的 group_id - group_id_to_query = DEFAULT_GROUP_ID - # 使用 asyncio.run 来执行异步主函数 - top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query)) + # 1. 从数据库获取原始排名靠前的标签 + raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) + if not raw_tags_with_freq: + return [] - if top_tags: - print(f"热门记忆标签 (Group ID: {group_id_to_query}, 经LLM筛选):") - for tag, frequency in top_tags: - print(f"- {tag} (数量: {frequency})") + raw_tag_names = [tag for tag, freq in raw_tags_with_freq] - # --- 将结果写入统一的 Signboard.json 到 logs/memory-output --- - from app.core.config import settings - settings.ensure_memory_output_dir() - signboard_path = settings.get_memory_output_path("Signboard.json") - payload = { - "group_id": group_id_to_query, - "hot_tags": [{"name": t, "frequency": f} for t, f in top_tags] - } - try: - existing = {} - if os.path.exists(signboard_path): - with open(signboard_path, "r", encoding="utf-8") as rf: - existing = json.load(rf) - existing["hot_memory_tags"] = payload - with open(signboard_path, "w", encoding="utf-8") as wf: - json.dump(existing, wf, ensure_ascii=False, indent=2) - print(f"已写入 {signboard_path} -> hot_memory_tags") - except Exception as e: - print(f"写入 Signboard.json 失败: {e}") - else: - print(f"在 Group ID '{group_id_to_query}' 中没有找到符合条件的实体标签。") - except Exception as e: - print(f"执行过程中发生严重错误: {e}") - print("请检查:") - print("1. Neo4j数据库服务是否正在运行。") - print("2. 'config.json'中的配置是否正确。") - print("3. 相关的环境变量 (如 NEO4J_PASSWORD, DASHSCOPE_API_KEY) 是否已正确设置。") + # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 + meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) + + # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) + final_tags = [] + for tag, freq in raw_tags_with_freq: + if tag in meaningful_tag_names: + final_tags.append((tag, freq)) + + return final_tags + finally: + # 确保关闭连接 + await connector.close() diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index ae2b9cfa..91e47eae 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -131,179 +131,60 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") return results -# ============================================================================ -# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考 -# ============================================================================ -# def rerank_hybrid_results( -# keyword_results: Dict[str, List[Dict[str, Any]]], -# embedding_results: Dict[str, List[Dict[str, Any]]], -# alpha: float = 0.6, -# limit: int = 10 -# ) -> Dict[str, List[Dict[str, Any]]]: -# """ -# Rerank hybrid search results by combining BM25 and embedding scores. -# -# 已废弃:此函数功能已被 rerank_with_activation 完全替代 -# -# Args: -# keyword_results: Results from keyword/BM25 search -# embedding_results: Results from embedding search -# alpha: Weight for BM25 scores (1-alpha for embedding scores) -# limit: Maximum number of results to return per category -# -# Returns: -# Reranked results with combined scores -# """ -# reranked = {} -# -# for category in ["statements", "chunks", "entities","summaries"]: -# keyword_items = keyword_results.get(category, []) -# embedding_items = embedding_results.get(category, []) -# -# # Normalize scores within each search type -# keyword_items = normalize_scores(keyword_items, "score") -# embedding_items = normalize_scores(embedding_items, "score") -# -# # Create a combined pool of unique items -# combined_items = {} -# -# # Add keyword results with BM25 scores -# for item in keyword_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if item_id: -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) -# combined_items[item_id]["embedding_score"] = 0 # Default -# -# # Add or update with embedding results -# for item in embedding_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if item_id: -# if item_id in combined_items: -# # Update existing item with embedding score -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# else: -# # New item from embedding search only -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = 0 # Default -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# -# # Calculate combined scores and rank -# for item_id, item in combined_items.items(): -# bm25_score = item.get("bm25_score", 0) -# embedding_score = item.get("embedding_score", 0) -# -# # Combined score: weighted average of normalized scores -# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score -# item["combined_score"] = combined_score -# -# # Keep original score for reference -# if "score" not in item and bm25_score > 0: -# item["score"] = bm25_score -# elif "score" not in item and embedding_score > 0: -# item["score"] = embedding_score -# -# # Sort by combined score and limit results -# sorted_items = sorted( -# combined_items.values(), -# key=lambda x: x.get("combined_score", 0), -# reverse=True -# )[:limit] -# -# reranked[category] = sorted_items -# -# return reranked - -# def rerank_with_forgetting_curve( -# keyword_results: Dict[str, List[Dict[str, Any]]], -# embedding_results: Dict[str, List[Dict[str, Any]]], -# alpha: float = 0.6, -# limit: int = 10, -# forgetting_config: ForgettingEngineConfig | None = None, -# now: datetime | None = None, -# ) -> Dict[str, List[Dict[str, Any]]]: -# """ -# Rerank hybrid results with a forgetting curve applied to combined scores. -# -# 已废弃:此函数功能已被 rerank_with_activation 完全替代 -# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度) -# -# The forgetting curve reduces scores for older memories or weaker connections. -# -# Args: -# keyword_results: Results from keyword/BM25 search -# embedding_results: Results from embedding search -# alpha: Weight for BM25 scores (1-alpha for embedding scores) -# limit: Maximum number of results to return per category -# forgetting_config: Configuration for the forgetting engine -# now: Optional current time override for testing -# -# Returns: -# Reranked results with combined and final scores (after forgetting) -# """ -# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig()) -# now_dt = now or datetime.now() -# -# reranked: Dict[str, List[Dict[str, Any]]] = {} -# -# for category in ["statements", "chunks", "entities","summaries"]: -# keyword_items = keyword_results.get(category, []) -# embedding_items = embedding_results.get(category, []) -# -# # Normalize scores within each search type -# keyword_items = normalize_scores(keyword_items, "score") -# embedding_items = normalize_scores(embedding_items, "score") -# -# combined_items: Dict[str, Dict[str, Any]] = {} -# -# # Combine two result sets by ID -# for src_items, is_embedding in ( -# (keyword_items, False), (embedding_items, True) -# ): -# for item in src_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if not item_id: -# continue -# existing = combined_items.get(item_id) -# if not existing: -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = 0 -# combined_items[item_id]["embedding_score"] = 0 -# # Update normalized score from the right source -# if is_embedding: -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# else: -# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) -# -# # Calculate scores and apply forgetting weights -# for item_id, item in combined_items.items(): -# bm25_score = float(item.get("bm25_score", 0) or 0) -# embedding_score = float(item.get("embedding_score", 0) or 0) -# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score -# -# # Estimate time elapsed in days -# dt = _parse_datetime(item.get("created_at")) -# if dt is None: -# time_elapsed_days = 0.0 -# else: -# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) -# -# # Memory strength (currently set to default value) -# memory_strength = 1.0 -# forgetting_weight = engine.calculate_weight( -# time_elapsed=time_elapsed_days, memory_strength=memory_strength -# ) -# final_score = combined_score * forgetting_weight -# item["combined_score"] = final_score -# -# sorted_items = sorted( -# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True -# )[:limit] -# -# reranked[category] = sorted_items -# -# return reranked +def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Remove duplicate items from search results based on content. + + Deduplication strategy: + 1. First try to deduplicate by ID (id, uuid, or chunk_id) + 2. Then deduplicate by content hash (text, content, statement, or name fields) + + Args: + items: List of search result items + + Returns: + Deduplicated list of items, preserving the order of first occurrence + """ + seen_ids = set() + seen_content = set() + deduplicated = [] + + for item in items: + # Try multiple ID fields to identify unique items + item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") + + # Extract content from various possible fields + content = ( + item.get("text") or + item.get("content") or + item.get("statement") or + item.get("name") or + "" + ) + + # Normalize content for comparison (strip whitespace and lowercase) + normalized_content = str(content).strip().lower() if content else "" + + # Check if we've seen this ID or content before + is_duplicate = False + + if item_id and item_id in seen_ids: + is_duplicate = True + elif normalized_content and normalized_content in seen_content: + # Only check content duplication if content is not empty + is_duplicate = True + + if not is_duplicate: + # Mark as seen + if item_id: + seen_ids.add(item_id) + if normalized_content: # Only track non-empty content + seen_content.add(normalized_content) + + deduplicated.append(item) + + return deduplicated def rerank_with_activation( @@ -364,7 +245,7 @@ def rerank_with_activation( keyword_items = normalize_scores(keyword_items, "score") embedding_items = normalize_scores(embedding_items, "score") - # 步骤 2: 按 ID 合并结果 + # 步骤 2: 按 ID 合并结果(去重) combined_items: Dict[str, Dict[str, Any]] = {} # 添加关键词结果 @@ -507,6 +388,9 @@ def rerank_with_activation( # 无激活值:使用内容相关性分数 item["final_score"] = item.get("base_score", 0) + # 最终去重确保没有重复项 + sorted_items = _deduplicate_results(sorted_items) + reranked[category] = sorted_items return reranked @@ -1144,96 +1028,3 @@ async def search_chunk_by_chunk_id( ) return {"chunks": chunks} - -# def main(): -# """Main entry point for the hybrid graph search CLI. - -# Parses command line arguments and executes search with specified parameters. -# Supports keyword, embedding, and hybrid search modes. -# """ -# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options") -# parser.add_argument( -# "--query", "-q", required=True, help="Free-text query to search" -# ) -# parser.add_argument( -# "--search-type", -# "-t", -# choices=["keyword", "embedding", "hybrid"], -# default="hybrid", -# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)" -# ) -# parser.add_argument( -# "--config-id", -# "-c", -# type=int, -# required=True, -# help="Database configuration ID (required)", -# ) -# parser.add_argument( -# "--group-id", -# "-g", -# default=None, -# help="Optional group_id to filter results (default: None)", -# ) -# parser.add_argument( -# "--limit", -# "-k", -# type=int, -# default=5, -# help="Max number of results per type (default: 5)", -# ) -# parser.add_argument( -# "--include", -# "-i", -# nargs="+", -# default=["statements", "chunks", "entities", "summaries"], -# choices=["statements", "chunks", "entities", "summaries"], -# help="Which targets to search for embedding search (default: statements chunks entities summaries)" -# ) -# parser.add_argument( -# "--output", -# "-o", -# default="search_results.json", -# help="Path to save the search results JSON (default: search_results.json)", -# ) -# parser.add_argument( -# "--rerank-alpha", -# "-a", -# type=float, -# default=0.6, -# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)", -# ) -# parser.add_argument( -# "--forgetting-rerank", -# action="store_true", -# help="Apply forgetting curve during reranking for hybrid search.", -# ) -# parser.add_argument( -# "--llm-rerank", -# action="store_true", -# help="Apply LLM-based reranking for hybrid search.", -# ) -# args = parser.parse_args() - -# # Load memory config from database -# from app.services.memory_config_service import MemoryConfigService -# memory_config = MemoryConfigService.load_memory_config(args.config_id) - -# asyncio.run( -# run_hybrid_search( -# query_text=args.query, -# search_type=args.search_type, -# group_id=args.group_id, -# limit=args.limit, -# include=args.include, -# output_path=args.output, -# memory_config=memory_config, -# rerank_alpha=args.rerank_alpha, -# use_forgetting_rerank=args.forgetting_rerank, -# use_llm_rerank=args.llm_rerank, -# ) -# ) - - -# if __name__ == "__main__": -# main() diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index e9fb8855..d39c9dbb 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -18,21 +18,13 @@ from enum import Enum from typing import Any, Dict, List, Optional from app.core.memory.llm_tools.openai_client import OpenAIClient -from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config.get_data import ( extract_and_process_changes, get_data, get_data_statement, ) -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.prompt.template_render import ( - render_evaluate_prompt, - render_reflexion_prompt, -) from app.core.models.base import RedBearModelConfig -from app.core.response_utils import success from app.repositories.neo4j.cypher_queries import ( - UPDATE_STATEMENT_INVALID_AT, neo4j_query_all, neo4j_query_part, neo4j_statement_all, @@ -160,12 +152,11 @@ class ReflectionEngine: self.neo4j_connector = Neo4jConnector() if self.llm_client is None: - from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context with get_db_context() as db: factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) + self.llm_client = factory.get_llm_client(self.config.model_id) elif isinstance(self.llm_client, str): # 如果 llm_client 是字符串(model_id),则用它初始化客户端 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -263,25 +254,23 @@ class ReflectionEngine: # 2. 检测冲突(基于事实的反思) conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) - print(100 * '-') - print(conflict_data) - print(100 * '-') - # # 检查是否真的有冲突 - conflicts_found='' + conflict_list=[] + for i in conflict_data: + conflict_list.append(i['data']) - conflicts_found='' + + + conflicts_found=0 # 3. 解决冲突 - solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) + solved_data = await self._resolve_conflicts(conflict_list, statement_databasets) + if not solved_data: return ReflectionResult( success=False, - message="反思失败,未解决冲突", + message=f"没有{self.config.baseline}相关的冲突数据", conflicts_found=conflicts_found, execution_time=asyncio.get_event_loop().time() - start_time ) - print(100 * '*') - print(solved_data) - print(100 * '*') conflicts_resolved = len(solved_data) logging.info(f"解决了 {conflicts_resolved} 个冲突") @@ -386,7 +375,7 @@ class ReflectionEngine: memory_verifies.append(item['memory_verify']) result_data['memory_verifies'] = memory_verifies result_data['quality_assessments'] = quality_assessments - conflicts_found='' + conflicts_found = 0 # 初始化为整数0而不是空字符串 REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"} # Clearn conflict_data,And memory_verify和quality_assessment cleaned_conflict_data = [] @@ -414,7 +403,7 @@ class ReflectionEngine: cleaned_conflict_data_.append(cleaned_item) print(cleaned_conflict_data_) # 3. 解决冲突 - solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) + solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data) if not solved_data: return ReflectionResult( success=False, @@ -739,4 +728,3 @@ class ReflectionEngine: raise ValueError(f"未知的反思基线: {self.config.baseline}") - diff --git a/api/app/core/memory/utils/config/__init__.py b/api/app/core/memory/utils/config/__init__.py index f69c13a2..9eef888c 100644 --- a/api/app/core/memory/utils/config/__init__.py +++ b/api/app/core/memory/utils/config/__init__.py @@ -14,28 +14,8 @@ from .config_utils import ( get_pruning_config, get_voice_config, ) - -# DEPRECATED: Global configuration variables removed -# Use MemoryConfig objects with dependency injection instead -# from .definitions import ( -# CONFIG, # DEPRECATED - empty dict for backward compatibility -# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility -# PROJECT_ROOT, # Still needed for file paths -# reload_configuration_from_database, # DEPRECATED - returns False -# ) -# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection from .get_data import get_data -# litellm_config 需要时动态导入,避免循环依赖 -# from .litellm_config import ( -# LiteLLMConfig, -# setup_litellm_enhanced, -# get_usage_summary, -# print_usage_summary, -# get_instant_qps, -# print_instant_qps, -# ) - __all__ = [ # config_utils "get_model_config", @@ -45,18 +25,5 @@ __all__ = [ "get_pruning_config", "get_picture_config", "get_voice_config", - # definitions (DEPRECATED - use MemoryConfig objects instead) - # "CONFIG", # DEPRECATED - # "RUNTIME_CONFIG", # DEPRECATED - # "PROJECT_ROOT", - # "reload_configuration_from_database", # DEPRECATED - # get_data "get_data", - # litellm_config - 需要时从 .litellm_config 直接导入 - # "LiteLLMConfig", - # "setup_litellm_enhanced", - # "get_usage_summary", - # "print_usage_summary", - # "get_instant_qps", - # "print_instant_qps", ] diff --git a/api/app/core/memory/utils/config/config_optimization.py b/api/app/core/memory/utils/config/config_optimization.py deleted file mode 100644 index 41848a80..00000000 --- a/api/app/core/memory/utils/config/config_optimization.py +++ /dev/null @@ -1,398 +0,0 @@ -""" -配置管理优化模块 - -提供可选的配置管理优化功能,包括: -- LRU 缓存策略 -- 缓存预热 -- 缓存监控指标 -- 动态 TTL 策略 -- 配置版本控制 - -这些优化是可选的,当前的基础实现已经满足大多数需求。 -""" -import logging -import statistics -import threading -from collections import OrderedDict -from datetime import datetime, timedelta -from typing import Dict, Any, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -class LRUConfigCache: - """ - LRU(Least Recently Used)配置缓存 - - 当缓存达到最大容量时,自动淘汰最少使用的配置 - """ - - def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)): - """ - 初始化 LRU 缓存 - - Args: - max_size: 最大缓存容量 - ttl: 缓存过期时间 - """ - self.max_size = max_size - self.ttl = ttl - self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict() - self._timestamps: Dict[str, datetime] = {} - self._lock = threading.RLock() - - # 统计信息 - self._stats = { - 'hits': 0, - 'misses': 0, - 'evictions': 0, - 'load_times': [] - } - - def get(self, config_id: str) -> Optional[Dict[str, Any]]: - """ - 获取配置(如果存在且未过期) - - Args: - config_id: 配置 ID - - Returns: - 配置字典,如果不存在或已过期则返回 None - """ - with self._lock: - if config_id not in self._cache: - self._stats['misses'] += 1 - return None - - # 检查是否过期 - timestamp = self._timestamps.get(config_id) - if timestamp and (datetime.now() - timestamp) >= self.ttl: - # 过期,移除 - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - self._stats['misses'] += 1 - return None - - # 命中,移动到末尾(标记为最近使用) - self._cache.move_to_end(config_id) - self._stats['hits'] += 1 - return self._cache[config_id] - - def put(self, config_id: str, config: Dict[str, Any]) -> None: - """ - 添加或更新配置 - - Args: - config_id: 配置 ID - config: 配置字典 - """ - with self._lock: - if config_id in self._cache: - # 更新现有配置 - self._cache.move_to_end(config_id) - else: - # 添加新配置 - if len(self._cache) >= self.max_size: - # 缓存已满,移除最旧的配置 - oldest_id, _ = self._cache.popitem(last=False) - self._timestamps.pop(oldest_id, None) - self._stats['evictions'] += 1 - logger.debug(f"[LRUCache] 淘汰配置: {oldest_id}") - - self._cache[config_id] = config - self._timestamps[config_id] = datetime.now() - - def clear(self, config_id: Optional[str] = None) -> None: - """ - 清除缓存 - - Args: - config_id: 如果指定,只清除该配置;否则清除所有 - """ - with self._lock: - if config_id: - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - else: - self._cache.clear() - self._timestamps.clear() - - def get_stats(self) -> Dict[str, Any]: - """ - 获取缓存统计信息 - - Returns: - 统计信息字典 - """ - with self._lock: - total = self._stats['hits'] + self._stats['misses'] - hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0 - - return { - 'cache_size': len(self._cache), - 'max_size': self.max_size, - 'total_requests': total, - 'cache_hits': self._stats['hits'], - 'cache_misses': self._stats['misses'], - 'evictions': self._stats['evictions'], - 'hit_rate': hit_rate, - 'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0 - } - - def record_load_time(self, load_time_ms: float) -> None: - """ - 记录加载时间 - - Args: - load_time_ms: 加载时间(毫秒) - """ - with self._lock: - self._stats['load_times'].append(load_time_ms) - # 只保留最近 1000 次的记录 - if len(self._stats['load_times']) > 1000: - self._stats['load_times'] = self._stats['load_times'][-1000:] - - -class ConfigCacheWarmer: - """ - 配置缓存预热器 - - 在系统启动时预加载常用配置,减少首次请求延迟 - """ - - @staticmethod - def warmup(config_ids: List[str], load_func) -> Dict[str, bool]: - """ - 预热缓存 - - Args: - config_ids: 要预加载的配置 ID 列表 - load_func: 配置加载函数 - - Returns: - 每个配置的加载结果 - """ - results = {} - - logger.info(f"[CacheWarmer] 开始预热 {len(config_ids)} 个配置") - - for config_id in config_ids: - try: - result = load_func(config_id) - results[config_id] = result - if result: - logger.debug(f"[CacheWarmer] 成功预热配置: {config_id}") - else: - logger.warning(f"[CacheWarmer] 预热配置失败: {config_id}") - except Exception as e: - logger.error(f"[CacheWarmer] 预热配置异常: {config_id}, 错误: {e}") - results[config_id] = False - - success_count = sum(1 for r in results.values() if r) - logger.info(f"[CacheWarmer] 预热完成: {success_count}/{len(config_ids)} 成功") - - return results - - -class DynamicTTLStrategy: - """ - 动态 TTL 策略 - - 根据配置类型和更新频率动态调整缓存过期时间 - """ - - # 预定义的 TTL 策略 - TTL_STRATEGIES = { - 'production': timedelta(minutes=30), # 生产配置较稳定 - 'staging': timedelta(minutes=15), # 预发布配置中等稳定 - 'development': timedelta(minutes=5), # 开发配置频繁变化 - 'testing': timedelta(minutes=1), # 测试配置快速过期 - 'default': timedelta(minutes=5) # 默认策略 - } - - @classmethod - def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta: - """ - 获取配置的 TTL - - Args: - config_id: 配置 ID - config_type: 配置类型(production/staging/development/testing) - - Returns: - TTL 时间间隔 - """ - if config_type and config_type in cls.TTL_STRATEGIES: - return cls.TTL_STRATEGIES[config_type] - - # 根据 config_id 推断类型 - if 'prod' in config_id.lower(): - return cls.TTL_STRATEGIES['production'] - elif 'stag' in config_id.lower(): - return cls.TTL_STRATEGIES['staging'] - elif 'dev' in config_id.lower(): - return cls.TTL_STRATEGIES['development'] - elif 'test' in config_id.lower(): - return cls.TTL_STRATEGIES['testing'] - - return cls.TTL_STRATEGIES['default'] - - -class ConfigVersionManager: - """ - 配置版本管理器 - - 跟踪配置版本,当配置更新时自动失效旧版本缓存 - """ - - def __init__(self): - self._versions: Dict[str, str] = {} - self._lock = threading.RLock() - - def get_version(self, config_id: str) -> Optional[str]: - """ - 获取配置版本 - - Args: - config_id: 配置 ID - - Returns: - 版本号,如果不存在则返回 None - """ - with self._lock: - return self._versions.get(config_id) - - def set_version(self, config_id: str, version: str) -> None: - """ - 设置配置版本 - - Args: - config_id: 配置 ID - version: 版本号 - """ - with self._lock: - old_version = self._versions.get(config_id) - self._versions[config_id] = version - - if old_version and old_version != version: - logger.info(f"[VersionManager] 配置版本更新: {config_id} {old_version} -> {version}") - - def check_version(self, config_id: str, cached_version: Optional[str]) -> bool: - """ - 检查缓存版本是否有效 - - Args: - config_id: 配置 ID - cached_version: 缓存的版本号 - - Returns: - True 如果版本匹配,False 如果版本不匹配或不存在 - """ - with self._lock: - current_version = self._versions.get(config_id) - - if not current_version or not cached_version: - return False - - return current_version == cached_version - - def invalidate(self, config_id: str) -> None: - """ - 使配置版本失效 - - Args: - config_id: 配置 ID - """ - with self._lock: - if config_id in self._versions: - # 生成新版本号 - import uuid - new_version = str(uuid.uuid4()) - self._versions[config_id] = new_version - logger.info(f"[VersionManager] 配置版本失效: {config_id} -> {new_version}") - - -class CacheMonitor: - """ - 缓存监控器 - - 提供缓存性能监控和报告功能 - """ - - def __init__(self, cache: LRUConfigCache): - self.cache = cache - - def get_report(self) -> str: - """ - 生成缓存性能报告 - - Returns: - 格式化的报告字符串 - """ - stats = self.cache.get_stats() - - report = f""" -配置缓存性能报告 -================ -缓存容量: {stats['cache_size']}/{stats['max_size']} -总请求数: {stats['total_requests']} -缓存命中: {stats['cache_hits']} -缓存未命中: {stats['cache_misses']} -缓存命中率: {stats['hit_rate']:.2f}% -淘汰次数: {stats['evictions']} -平均加载时间: {stats['avg_load_time']:.2f}ms -""" - return report - - def log_stats(self) -> None: - """记录统计信息到日志""" - stats = self.cache.get_stats() - logger.info( - f"[CacheMonitor] 缓存统计 - " - f"容量: {stats['cache_size']}/{stats['max_size']}, " - f"命中率: {stats['hit_rate']:.2f}%, " - f"淘汰: {stats['evictions']}" - ) - - -# 使用示例 -def example_usage(): - """ - 优化功能使用示例 - """ - # 1. 使用 LRU 缓存 - lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5)) - - # 获取配置 - config = lru_cache.get("config_001") - if config is None: - # 缓存未命中,从数据库加载 - config = {"llm_name": "openai/gpt-4"} - lru_cache.put("config_001", config) - - # 2. 预热缓存 - def load_config(config_id): - # 实际的配置加载逻辑 - return True - - warmer = ConfigCacheWarmer() - results = warmer.warmup(["config_001", "config_002"], load_config) - - # 3. 动态 TTL - ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production") - print(f"TTL: {ttl}") - - # 4. 版本管理 - version_manager = ConfigVersionManager() - version_manager.set_version("config_001", "v1.0.0") - - # 检查版本 - is_valid = version_manager.check_version("config_001", "v1.0.0") - - # 5. 监控 - monitor = CacheMonitor(lru_cache) - print(monitor.get_report()) - - -if __name__ == "__main__": - example_usage() diff --git a/api/app/core/memory/utils/config/definitions.py b/api/app/core/memory/utils/config/definitions.py deleted file mode 100644 index fc07c2cc..00000000 --- a/api/app/core/memory/utils/config/definitions.py +++ /dev/null @@ -1,268 +0,0 @@ -# """ -# 配置加载模块 - DEPRECATED - -# ⚠️ DEPRECATION NOTICE ⚠️ -# This module is deprecated and will be removed in a future version. -# Global configuration variables have been eliminated in favor of dependency injection. - -# Use the new MemoryConfig system instead: -# - app.schemas.memory_config_schema.MemoryConfig for configuration objects -# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id) - -# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED -# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED -# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED -# """ -# import json -# import os -# import threading -# from datetime import datetime, timedelta -# from typing import Any, Dict, Optional - -# #TODO: Fix this - -# try: -# from dotenv import load_dotenv -# load_dotenv() -# except Exception: -# pass - -# # Import unified configuration system -# try: -# from app.core.config import settings -# USE_UNIFIED_CONFIG = True -# except ImportError: -# USE_UNIFIED_CONFIG = False -# settings = None - -# # PROJECT_ROOT 应该指向 app/core/memory/ 目录 -# # __file__ = app/core/memory/utils/config/definitions.py -# # os.path.dirname(__file__) = app/core/memory/utils/config -# # os.path.dirname(...) = app/core/memory/utils -# # os.path.dirname(...) = app/core/memory -# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# # DEPRECATED: Global configuration lock removed -# # Use MemoryConfig objects with dependency injection instead - -# # DEPRECATED: Legacy config.json loading removed -# # Use MemoryConfig objects with dependency injection instead -# CONFIG = {} - -# DEFAULT_VALUES = { -# "llm_name": "openai/qwen-plus", -# "embedding_name": "openai/nomic-embed-text:v1.5", -# "chunker_strategy": "RecursiveChunker", -# "group_id": "group_123", -# "user_id": "default_user", -# "apply_id": "default_apply", -# "llm_agent_name": "openai/qwen-plus", -# "llm_verify_name": "openai/qwen-plus", -# "llm_image_recognition": "openai/qwen-plus", -# "llm_voice_recognition": "openai/qwen-plus", -# "prompt_level": "DEBUG", -# "reflexion_iteration_period": "3", -# "reflexion_range": "retrieval", -# "reflexion_baseline": "TIME", -# } - -# # DEPRECATED: Legacy global variables for backward compatibility only -# # These will be removed in a future version -# # Use MemoryConfig objects with dependency injection instead -# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true" -# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"]) - - -# # 阶段 1: 从 runtime.json 加载配置(路径 A) -# def _load_from_runtime_json() -> Dict[str, Any]: -# """ -# DEPRECATED: Legacy runtime.json loading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Dict[str, Any]: Empty configuration (legacy support only) -# """ -# import warnings -# warnings.warn( -# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return {"selections": {}} - - -# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器 -# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代 -# # 保留此函数仅为向后兼容 -# def _load_from_database() -> Optional[Dict[str, Any]]: -# """ -# DEPRECATED: Legacy database configuration loading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Optional[Dict[str, Any]]: None (deprecated functionality) -# """ -# import warnings -# warnings.warn( -# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return None - - -# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED -# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None: -# """ -# DEPRECATED: 将运行时配置暴露为全局常量供项目使用 - -# ⚠️ This function is deprecated and will be removed in a future version. -# Global configuration variables have been eliminated in favor of dependency injection. - -# Use the new MemoryConfig system instead: -# - app.core.memory_config.config.MemoryConfig for configuration objects -# - Pass configuration objects as parameters instead of using global variables - -# Args: -# runtime_cfg: 运行时配置字典 -# """ -# import warnings -# warnings.warn( -# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# # Keep minimal global state for backward compatibility only -# # These will be removed in a future version -# global RUNTIME_CONFIG, SELECTIONS - -# RUNTIME_CONFIG = runtime_cfg -# SELECTIONS = RUNTIME_CONFIG.get("selections", {}) - -# # All other global variables have been removed -# # Use MemoryConfig objects instead - - -# # 初始化:使用统一配置加载器 -# def _initialize_configuration() -> None: -# """ -# DEPRECATED: Legacy configuration initialization - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. -# """ -# import warnings -# warnings.warn( -# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# # Initialize with empty configuration for backward compatibility -# _expose_runtime_constants({"selections": {}}) - - -# # 模块加载时自动初始化配置 -# _initialize_configuration() - -# # DEPRECATED: Global variables removed -# # These variables have been eliminated in favor of dependency injection -# # Use MemoryConfig objects instead of accessing global variables - - -# # 公共 API:动态重新加载配置 -# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool: -# """ -# DEPRECATED: Legacy configuration reloading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# For new code, use: -# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() -# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - -# Args: -# config_id: Configuration ID (deprecated) -# force_reload: Force reload flag (deprecated) - -# Returns: -# bool: Always returns False (deprecated functionality) -# """ -# import logging -# import warnings - -# logger = logging.getLogger(__name__) - -# warnings.warn( -# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. " -# "Use MemoryConfig objects with dependency injection instead.") - -# return False - - - - - -# def get_current_config_id() -> Optional[str]: -# """ -# DEPRECATED: Legacy config ID retrieval - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Optional[str]: None (deprecated functionality) -# """ -# import warnings -# warnings.warn( -# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return None - - -# def ensure_fresh_config(config_id = None) -> bool: -# """ -# DEPRECATED: Legacy configuration freshness check - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# For new code, use: -# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() -# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - -# Args: -# config_id: Configuration ID (deprecated) - -# Returns: -# bool: Always returns False (deprecated functionality) -# """ -# import logging -# import warnings - -# logger = logging.getLogger(__name__) - -# warnings.warn( -# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. " -# "Use MemoryConfig objects with dependency injection instead.") - -# return False - - diff --git a/api/app/core/memory/utils/config/get_example_data.py b/api/app/core/memory/utils/config/get_example_data.py deleted file mode 100644 index c466645b..00000000 --- a/api/app/core/memory/utils/config/get_example_data.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import re -import uuid -import random -import string -from typing import List, Dict, Optional - -# 生成包含字母(大小写)和数字的随机字符串 -def generate_random_string(length=16): - characters = string.ascii_letters + string.digits - return ''.join(random.choice(characters) for _ in range(length)) - -def get_example_data() -> List[Dict[str, Optional[str]]]: - """ - 从句子提取日志中获取数据 - Content: 在苹果公司中国总部,用户和李华偶遇了从美国来的技术专家约翰·史密斯。 - Created At: 2025-11-28 19:28:38.256421 - Expired At: None - Valid At: None - Invalid At: None - 将数据构造成如下形式: - [ - { - "id":id, - "group_id":group_id, - "statement": Content, - "created_at": Created At, - "expired_at": Expired At, - "valid_at": Valid At, - "invalid_at": Invalid At, - "chunk_id": "86da9022710c40eaa5f518a294c398d2", - "entity_ids": [] - }, - ... - ] - """ - # 获取日志文件路径 - log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt") - - # 检查文件是否存在 - if not os.path.exists(log_file_path): - return [] - - # 读取日志文件 - with open(log_file_path, "r", encoding="utf-8") as f: - content = f.read() - - # 解析数据 - results = [] - - # 使用正则表达式分割每个 Statement - statement_blocks = re.split(r"Statement \d+:", content) - - for block in statement_blocks[1:]: # 跳过第一个空块 - # 提取各个字段 - id_match = re.search(r"Id:\s*(.+?)(?=\n)", block) - group_id_match = re.search(r"Group Id:\s*(.+?)(?=\n)", block) - statement_match = re.search(r"Content:\s*(.+?)(?=\n)", block) - created_at_match = re.search(r"Created At:\s*(.+?)(?=\n)", block) - expired_at_match = re.search(r"Expired At:\s*(.+?)(?=\n)", block) - valid_at_match = re.search(r"Valid At:\s*(.+?)(?=\n)", block) - invalid_at_match = re.search(r"Invalid At:\s*(.+?)(?=\n)", block) - chunk_id_match = re.search(r"Chunk Id:\s*(.+?)(?=\n)", block) - - # 构造字典 - if statement_match: - statement_data = { - "id": id_match.group(1).strip() if id_match else generate_random_string(), - "group_id": group_id_match.group(1).strip() if group_id_match else "group_example", - "statement": statement_match.group(1).strip(), - "created_at": created_at_match.group(1).strip() if created_at_match else None, - "expired_at": expired_at_match.group(1).strip() if expired_at_match else None, - "valid_at": valid_at_match.group(1).strip() if valid_at_match else None, - "invalid_at": invalid_at_match.group(1).strip() if invalid_at_match else None, - "chunk_id": chunk_id_match.group(1).strip() if chunk_id_match else "chunk_example", - "entity_ids": [] - } - - # 将 "None" 字符串转换为 None - for key in ["created_at", "expired_at", "valid_at", "invalid_at"]: - if statement_data[key] == "None": - statement_data[key] = None - - results.append(statement_data) - - return results - - -if __name__ == "__main__": - print(f"获取数据如下:\n {get_example_data()}") \ No newline at end of file diff --git a/api/app/core/memory/utils/config/litellm_config.py b/api/app/core/memory/utils/config/litellm_config.py deleted file mode 100644 index dbf991a8..00000000 --- a/api/app/core/memory/utils/config/litellm_config.py +++ /dev/null @@ -1,516 +0,0 @@ -""" -LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring -""" - -import litellm -from typing import Dict, Any, List -import json -from datetime import datetime, timedelta -import os -import time -from collections import defaultdict -import threading -from queue import Queue - -class LiteLLMConfig: - """Configuration class for LiteLLM with enhanced retry and tracking capabilities""" - - def __init__(self): - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], # Store precise timestamps - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] # Store QPS measurements over time - }) - self.start_time = datetime.now() - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - - # Rate limiting for AWS Bedrock (conservative limits) - self.rate_limits = { - 'bedrock': { - 'requests_per_minute': 2, # AWS Bedrock default is very low - 'requests_per_second': 0.033, # 2/60 = 0.033 RPS - 'last_request_time': 0, - 'request_queue': Queue(), - 'lock': threading.Lock() - } - } - self.rate_limiting_enabled = True - - def setup_enhanced_config(self, max_retries: int = 3): - """Configure LiteLLM with retry logic and instant QPS tracking""" - - litellm.num_retries = max_retries - litellm.request_timeout = 300 - - litellm.retry_policy = { - "RateLimitError": { - "max_retries": 5, - "exponential_backoff": True, - "initial_delay": 1, - "max_delay": 60, - "jitter": True - }, - "APIConnectionError": { - "max_retries": 3, - "exponential_backoff": True, - "initial_delay": 2, - "max_delay": 30, - "jitter": True - }, - "InternalServerError": { - "max_retries": 2, - "exponential_backoff": True, - "initial_delay": 5, - "max_delay": 60, - "jitter": True - }, - "BadRequestError": { - "max_retries": 1, - "exponential_backoff": False, - "initial_delay": 1, - "max_delay": 5 - } - } - - litellm.success_callback = [self._success_callback] - litellm.failure_callback = [self._failure_callback] - litellm.completion_cost_tracking = True - litellm.set_verbose = False - litellm.modify_params = True - - print("✅ LiteLLM configured with instant QPS tracking and rate limiting") - - def _success_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for successful requests with module-specific QPS tracking""" - try: - # Extract usage information - usage = completion_response.get('usage', {}) - model = kwargs.get('model', 'unknown') - - # Extract module information from metadata or model name - module = self._extract_module_name(kwargs, model) - - # Calculate cost - cost = 0.0 - try: - cost = litellm.completion_cost(completion_response) - except: - pass - - # Calculate duration - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Record usage data - usage_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "input_tokens": usage.get('prompt_tokens', 0), - "output_tokens": usage.get('completion_tokens', 0), - "total_tokens": usage.get('total_tokens', 0), - "cost": cost, - "duration_seconds": duration_seconds, - "status": "success" - } - - self.usage_data.append(usage_record) - - # Update module-specific stats for QPS tracking - self._update_module_stats(module, usage_record, success=True) - - # Print real-time feedback - print(f"✓ {model}: {usage_record['input_tokens']}→{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s") - - except Exception as e: - print(f"Warning: Success callback failed: {e}") - - def _failure_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for failed requests with module-specific error tracking""" - try: - model = kwargs.get('model', 'unknown') - module = self._extract_module_name(kwargs, model) - - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Handle different error response formats - error_message = "Unknown error" - error_type = "UnknownError" - - # According to LiteLLM docs, completion_response contains the exception for failures - if completion_response is not None: - error_message = str(completion_response) - error_type = type(completion_response).__name__ - - # Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events) - elif 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - # Check for other error formats in kwargs - elif 'error' in kwargs: - error = kwargs['error'] - error_message = str(error) - error_type = type(error).__name__ - - # Check log_event_type to confirm this is a failure event - log_event_type = kwargs.get('log_event_type', '') - if log_event_type == 'failed_api_call' and 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - error_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "error": error_message, - "error_type": error_type, - "duration_seconds": duration_seconds, - "status": "failed" - } - - self.error_data.append(error_record) - - # Update module-specific stats for error tracking - self._update_module_stats(module, error_record, success=False) - - # Print error feedback - print(f"✗ {model}: {error_type} - {error_message[:100]}") - - except Exception as e: - print(f"Warning: Failure callback failed: {e}") - # Debug: print the actual parameters to understand the structure - print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}") - print(f"Debug - completion_response type: {type(completion_response)}") - print(f"Debug - completion_response: {completion_response}") - - def _should_rate_limit(self, model: str) -> bool: - """Check if the model should be rate limited""" - if not self.rate_limiting_enabled: - return False - return model.startswith('bedrock/') or 'bedrock' in model.lower() - - def _enforce_rate_limit(self, model: str): - """Enforce rate limiting for AWS Bedrock models""" - if not self._should_rate_limit(model): - return - - provider = 'bedrock' - if provider not in self.rate_limits: - return - - rate_config = self.rate_limits[provider] - - with rate_config['lock']: - current_time = time.time() - time_since_last = current_time - rate_config['last_request_time'] - min_interval = 1.0 / rate_config['requests_per_second'] - - if time_since_last < min_interval: - sleep_time = min_interval - time_since_last - print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}") - time.sleep(sleep_time) - - rate_config['last_request_time'] = time.time() - - def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str: - """Extract module name from request context""" - # Try to get module from metadata - metadata = kwargs.get('metadata', {}) - if 'module' in metadata: - return metadata['module'] - - # Try to infer from model name or other context - if 'claude' in model.lower(): - return 'bedrock_client' - elif 'gpt' in model.lower() or 'openai' in model.lower(): - return 'openai_client' - elif 'embed' in model.lower(): - return 'embedder' - else: - return 'unknown' - - def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool): - """Update module-specific statistics with instant QPS tracking""" - current_timestamp = time.time() - current_time = datetime.now() - - # Initialize module stats if first request - if self.module_stats[module]['start_time'] is None: - self.module_stats[module]['start_time'] = current_time - - # Update counters - self.module_stats[module]['requests'] += 1 - self.module_stats[module]['last_request_time'] = current_time - self.module_stats[module]['request_timestamps'].append(current_timestamp) - self.global_request_timestamps.append(current_timestamp) - - # Calculate instant QPS for this module - self._calculate_instant_qps(module, current_timestamp) - - # Calculate global instant QPS - self._calculate_global_instant_qps(current_timestamp) - - if success: - self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0) - self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0) - self.module_stats[module]['cost'] += record.get('cost', 0.0) - else: - self.module_stats[module]['errors'] += 1 - - def _calculate_instant_qps(self, module: str, current_timestamp: float): - """Calculate instant QPS for a specific module using sliding window""" - # Keep only timestamps from last 1 second for instant QPS - cutoff_time = current_timestamp - 1.0 - timestamps = self.module_stats[module]['request_timestamps'] - - # Remove old timestamps - self.module_stats[module]['request_timestamps'] = [ - ts for ts in timestamps if ts >= cutoff_time - ] - - # Calculate current QPS (requests in last second) - current_qps = len(self.module_stats[module]['request_timestamps']) - self.module_stats[module]['current_qps'] = current_qps - - # Update max QPS if current is higher - if current_qps > self.module_stats[module]['max_qps']: - self.module_stats[module]['max_qps'] = current_qps - - # Store QPS history (keep last 60 measurements) - self.module_stats[module]['qps_history'].append(current_qps) - if len(self.module_stats[module]['qps_history']) > 60: - self.module_stats[module]['qps_history'].pop(0) - - def _calculate_global_instant_qps(self, current_timestamp: float): - """Calculate global instant QPS across all modules""" - # Keep only timestamps from last 1 second - cutoff_time = current_timestamp - 1.0 - self.global_request_timestamps = [ - ts for ts in self.global_request_timestamps if ts >= cutoff_time - ] - - # Calculate current global QPS - current_global_qps = len(self.global_request_timestamps) - - # Update max global QPS - if current_global_qps > self.global_max_qps: - self.global_max_qps = current_global_qps - - def get_instant_qps(self, module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - if module: - if module in self.module_stats: - return { - 'module': module, - 'current_qps': self.module_stats[module]['current_qps'], - 'max_qps': self.module_stats[module]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0 - } - else: - return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0} - else: - # Return data for all modules plus global - result = { - 'global': { - 'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]), - 'max_qps': self.global_max_qps - }, - 'modules': {} - } - - for mod in self.module_stats: - result['modules'][mod] = { - 'current_qps': self.module_stats[mod]['current_qps'], - 'max_qps': self.module_stats[mod]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0 - } - - return result - - def get_usage_summary(self) -> Dict[str, Any]: - """Get essential usage statistics""" - if not self.usage_data: - return { - "total_requests": 0, - "total_cost": 0.0, - "error_rate": 0.0, - "message": "No usage data available" - } - - total_requests = len(self.usage_data) - total_errors = len(self.error_data) - total_cost = sum(record['cost'] for record in self.usage_data) - total_input_tokens = sum(record['input_tokens'] for record in self.usage_data) - total_output_tokens = sum(record['output_tokens'] for record in self.usage_data) - - # Calculate session duration - duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60 - - # Build module statistics - module_stats = {} - for module, stats in self.module_stats.items(): - if stats['requests'] > 0: - module_stats[module] = { - "requests": stats['requests'], - "errors": stats['errors'], - "success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0, - "tokens_in": stats['tokens_in'], - "tokens_out": stats['tokens_out'], - "cost": stats['cost'], - "current_qps": stats['current_qps'], - "max_qps": stats['max_qps'] - } - - return { - "session_duration_minutes": duration_minutes, - "total_requests": total_requests, - "total_errors": total_errors, - "error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0, - "total_input_tokens": total_input_tokens, - "total_output_tokens": total_output_tokens, - "total_cost": total_cost, - "module_stats": module_stats, - "global_max_qps": self.global_max_qps - } - - def print_usage_summary(self): - """Print essential usage summary""" - stats = self.get_usage_summary() - - if stats.get('message'): - print(f"📊 {stats['message']}") - return - - print("\n📊 USAGE SUMMARY") - print(f"{'='*50}") - print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min") - print(f"📈 Requests: {stats['total_requests']}") - print(f"❌ Errors: {stats['total_errors']}") - print(f"💰 Cost: ${stats['total_cost']:.4f}") - print(f"🏆 Global Max QPS: {stats['global_max_qps']}") - - # Module statistics - if stats.get('module_stats'): - print("\n📦 MODULES:") - for module, mod_stats in stats['module_stats'].items(): - print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}") - - print(f"{'='*50}") - - def save_usage_data(self, filename: str = "litellm_usage.json"): - """Save usage data to JSON file""" - data = { - "summary": self.get_usage_summary(), - "detailed_usage": self.usage_data, - "errors": self.error_data, - "export_timestamp": datetime.now().isoformat() - } - - with open(filename, 'w') as f: - json.dump(data, f, indent=2) - - print(f"📁 Usage data saved to {filename}") - - def reset_tracking(self): - """Reset all tracking data""" - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] - }) - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - self.start_time = datetime.now() - print("🔄 All tracking data reset") - -# Global instance for easy access -litellm_config = LiteLLMConfig() - -def setup_litellm_enhanced(max_retries: int = 3): - """ - Quick setup function for LiteLLM enhanced configuration - - Args: - max_retries: Maximum number of retries for failed requests - """ - litellm_config.setup_enhanced_config(max_retries) - return litellm_config - -def get_usage_summary(): - """Get current usage summary""" - return litellm_config.get_usage_summary() - -def print_usage_summary(): - """Print current usage summary""" - litellm_config.print_usage_summary() - -def save_usage_data(filename: str = "litellm_usage.json"): - """Save usage data to file""" - litellm_config.save_usage_data(filename) - -def get_instant_qps(module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - return litellm_config.get_instant_qps(module) - -def print_instant_qps(module: str = None): - """Print instant QPS information""" - qps_data = get_instant_qps(module) - - print("\n⚡ INSTANT QPS MONITOR") - print(f"{'='*60}") - - if module: - print(f"Module: {qps_data['module']}") - print(f" Current QPS: {qps_data['current_qps']}") - print(f" Max QPS: {qps_data['max_qps']}") - print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}") - else: - # Global stats - global_data = qps_data.get('global', {}) - print("🌍 GLOBAL:") - print(f" Current QPS: {global_data.get('current_qps', 0)}") - print(f" Max QPS: {global_data.get('max_qps', 0)}") - - # Module stats - modules = qps_data.get('modules', {}) - if modules: - print("\n📦 MODULES:") - for mod, data in modules.items(): - print(f" {mod}:") - print(f" Current: {data['current_qps']} QPS") - print(f" Max: {data['max_qps']} QPS") - print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS") - - print(f"{'='*60}") - -def reset_tracking(): - """Reset all tracking data""" - litellm_config.reset_tracking() - -def get_module_stats() -> Dict[str, Dict[str, Any]]: - """Get detailed module statistics""" - summary = get_usage_summary() - return summary.get('module_stats', {}) diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index e649897a..5da6d4b5 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -24,7 +24,8 @@ - **身份冲突**: 同一实体被赋予不同类型或角色 - **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 ### 混合冲突 -检测所有逻辑不一致或相互矛盾的记录。 +- 检测所有逻辑不一致或相互矛盾的记录。 +- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 **检测原则**: - 重点检查相同实体的记录 - 分析description字段语义冲突 diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index ed3aad32..99660aa4 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -63,7 +63,7 @@ **脱敏字段**: name、entity1_name、entity2_name、description、relationship ## 4. 处理流程 - +###如果存在冲突数据执行以下步骤,不存在返回【】在data中 ### 步骤1: 类型匹配验证 **匹配规则**: - baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点) @@ -78,7 +78,7 @@ ### 步骤2: 冲突数据分组 **分组策略**: -- 时间冲突组: 涉及用户时间的记录 +- 时间冲突组: 涉及用户时间的记录比如(生日在2月17...) - 活动时间冲突组: 同一活动不同时间的记录 - 事实冲突组: 同一实体不同属性的记录 - 其他冲突组: 其他类型冲突记录 @@ -97,11 +97,12 @@ ### 处理规则 ** baseline是TIME - -保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的 -** baseline不是TIME + - 只处理时间相关的内容,比如时间表达式、日期、时间点 + -保留正确记录不变修改错误记录的expired_at为当前时间,比如(2025-12-16T12:00:00) +** baseline是FACT或者HYBRID + - 处理不是时间相关的内容 - 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段, 如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field - **核心原则**: - 只输出需要修改的记录 - 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录 @@ -110,22 +111,26 @@ - 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %} - 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空 - 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容, - ,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据 + ,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据,如果存在隐私审核,隐私审核是true,也需要放到reflexion的reason字段和solution **变更记录格式**: ```json "change": [ { "field": [ - {"id":修改字段对应的ID} - {"statement_id":需要修改的对象对应的statement_id} - {"字段名1": ["修改前的值1","修改后的值1"]}, - {"字段名2": ["修改前的值2","修改后的值2"]} + {"id": "修改字段对应的ID"}, + {"字段名1": ["修改前的值1", "修改后的值1"]}, + {"字段名2": ["修改前的值2", "修改后的值2"]} ] } ] ``` +**resolved_memory格式说明**: +- 对于TIME类型冲突: 只需expired_at字段即可 +- 对于FACT/HYBRID类型冲突: 需要包含完整的记录对象(包括name、entity1_name、entity2_name、description、relationship等所有相关字段) +- resolved_memory中只包含需要修改的记录,不需要修改的记录不要包含在内 + **类型不匹配处理**: - 冲突类型与baseline不匹配时,resolved设为null - reflexion.reason说明类型不匹配原因 @@ -157,7 +162,8 @@ "conflict": true }, "reflexion": { - "reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析, + "reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,如果 + 隐私审核打开的时候如果存在冲突,分析该冲突的原因 不可以随意分配冲突类型以及原因,不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种", "solution": "该冲突类型的解决方案(不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种)" }, diff --git a/api/app/core/memory/utils/self_reflexion_utils/__init__.py b/api/app/core/memory/utils/self_reflexion_utils/__init__.py deleted file mode 100644 index 422a83e3..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思工具模块 - -本模块提供自我反思引擎的核心功能,包括: -- 记忆冲突判定 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api 迁移而来。 -""" - -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - -__all__ = ["conflict", "reflexion", "self_reflexion"] diff --git a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py b/api/app/core/memory/utils/self_reflexion_utils/evaluate.py deleted file mode 100644 index 4d1835cd..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -"""记忆冲突判定模块 - -本模块提供记忆冲突判定功能,使用LLM判断记忆数据中是否存在冲突。 -从 app.core.memory.src.data_config_api.evaluate 迁移而来。 -""" - -import logging -import time -from typing import Any, List - -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.memory.utils.prompt.template_render import render_evaluate_prompt -from app.db import get_db_context -from app.schemas.memory_storage_schema import ConflictResultSchema -from pydantic import BaseModel - - -async def conflict(evaluate_data: List[Any]) -> List[Any]: - """ - Evaluates memory conflict using the evaluate.jinja2 template. - - Args: - evaluate_data: 反思数据列表。 - Returns: - 冲突记忆列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - print(f"====== 冲突判定开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ConflictResultSchema) - end_time = time.time() - print(f"冲突判定耗时: {end_time - start_time} 秒") - print(f"冲突判定原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。") - return [] - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。") - return [response] diff --git a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py b/api/app/core/memory/utils/self_reflexion_utils/reflexion.py deleted file mode 100644 index 1b915118..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -"""反思执行模块 - -本模块提供反思执行功能,使用LLM对冲突记忆进行反思和解决。 -从 app.core.memory.src.data_config_api.reflexion 迁移而来。 -""" - -import logging -import time -from typing import Any, List - -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.memory.utils.prompt.template_render import render_reflexion_prompt -from app.db import get_db_context -from app.schemas.memory_storage_schema import ReflexionResultSchema -from pydantic import BaseModel - - -async def reflexion(ref_data: List[Any]) -> List[Any]: - """ - Reflexes on the given reference data using the reflexion.jinja2 template. - - Args: - ref_data: 反思数据列表。 - Returns: - 反思结果列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - - print(f"====== 反思开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ReflexionResultSchema) - end_time = time.time() - print(f"反思耗时: {end_time - start_time} 秒") - print(f"反思原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 反思输出解析失败,返回空列表以继续流程。") - return [] - # 统一返回为列表[dict],便于自我反思主流程更新数据库 - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化反思返回类型,尝试直接封装为列表。") - return [response] diff --git a/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py b/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py deleted file mode 100644 index 934037b0..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思主执行模块 - -本模块提供自我反思引擎的主流程,包括: -- 获取反思数据 -- 冲突判断 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。 -""" - -import asyncio -import json -import logging -import os -import uuid -from typing import Any, Dict, List - -#TODO: Fix this - -# Default values (previously from definitions.py) -REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true" -REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3") -REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval") -REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME") - -from app.core.memory.utils.config.get_data import get_data -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo -from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from sqlalchemy.orm import Session - -# 并发限制(可通过环境变量覆盖) -CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5")) - -# 确保 INFO 级别日志输出到终端 -_root_logger = logging.getLogger() -if not _root_logger.handlers: - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") -else: - _root_logger.setLevel(logging.INFO) - - -async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]: - """ - 根据反思范围获取判断的记忆数据。 - - Args: - host_id: 主机ID - Returns: - 符合反思范围的记忆数据列表。 - """ - if REFLEXION_RANGE == "partial": - return await get_data(host_id) - elif REFLEXION_RANGE == "all": - return [] - else: - raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}") - - -async def run_conflict(conflict_data: List[Any]) -> List[Any]: - """ - 判断反思数据中是否存在冲突。 - - Args: - conflict_data: 冲突数据列表。 - Returns: - 如果存在冲突则返回冲突记忆列表,否则返回空列表。 - """ - if not conflict_data: - return [] - - conflict_data = await conflict(conflict_data) - # 仅保留存在冲突的条目(conflict == True) - try: - return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True] - except Exception: - return [] - - -async def run_reflexion(reflexion_data: List[Any]) -> Any: - """ - 执行反思,解决冲突。 - - Args: - reflexion_data: 反思数据列表。 - Returns: - 解决冲突后的反思结果(由 LLM 返回)。 - """ - if not reflexion_data: - return [] - # 并行对每个冲突进行反思,整体缩短等待时间 - sem = asyncio.Semaphore(CONCURRENCY) - - async def _reflex_one(item: Any) -> Dict[str, Any] | None: - async with sem: - try: - result_list = await reflexion([item]) - if not result_list: - return None - obj = result_list[0] - if hasattr(obj, "model_dump"): - return obj.model_dump() - elif hasattr(obj, "dict"): - return obj.dict() - elif isinstance(obj, dict): - return obj - except Exception as e: - logging.warning(f"反思失败,跳过一项: {e}") - return None - - tasks = [_reflex_one(item) for item in reflexion_data] - results = await asyncio.gather(*tasks, return_exceptions=False) - return [r for r in results if r] - - -async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str: - """ - 更新记忆库,将解决冲突后的记忆更新到记忆库中。 - - Args: - solved_data: 解决冲突后的记忆(由 LLM 返回)。 - host_id: 主机ID - Returns: - 更新结果(成功或失败)。 - """ - flag = False - if not solved_data: - return "数据缺失,更新失败" - if not isinstance(solved_data, list): - return "数据格式错误,更新失败" - neo4j_connector = Neo4jConnector() - try: - print(f"====== 更新记忆开始 ======\n") - - sem = asyncio.Semaphore(CONCURRENCY) - success_count = 0 - - async def _update_one(item: Dict[str, Any]) -> bool: - async with sem: - try: - if not isinstance(item, dict): - return False - if not item: - return False - resolved = item.get("resolved") - if not isinstance(resolved, dict) or not resolved: - logging.warning(f"反思结果无可更新内容,跳过此项: {item}") - return False - resolved_mem = resolved.get("resolved_memory") - if not isinstance(resolved_mem, dict) or not resolved_mem: - logging.warning(f"反思结果缺少 resolved_memory,跳过此项: {item}") - return False - group_id = resolved_mem.get("group_id") - id = resolved_mem.get("id") - # 使用 invalid_at 字段作为新的失效时间 - new_invalid_at = resolved_mem.get("invalid_at") - if not all([group_id, id, new_invalid_at]): - logging.warning(f"记忆更新参数缺失,跳过此项: {item}") - return False - await neo4j_connector.execute_query( - UPDATE_STATEMENT_INVALID_AT, - group_id=group_id, - id=id, - new_invalid_at=new_invalid_at, - ) - return True - except Exception as e: - logging.error(f"更新单条记忆失败: {e}") - return False - - tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)] - results = await asyncio.gather(*tasks, return_exceptions=False) - success_count = sum(1 for r in results if r) - - logging.info(f"成功更新 {success_count} 条记忆") - flag = success_count > 0 - return "更新成功" if flag else "更新失败" - except Exception as e: - logging.error(f"更新记忆库失败: {e}") - return "更新失败" - finally: - if flag: # 删除数据库中的检索数据 - db: Session = next(get_db()) - try: - db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete() - db.commit() - logging.info(f"成功删除 {success_count} 条检索数据") - except Exception as e: - logging.error(f"删除数据库中的检索数据失败: {e}") - finally: - db.close() - - - -async def _append_json(label: str, data: Any) -> None: - """记录冲突记忆(后台线程写入,避免阻塞事件循环)""" - def _write(): - with open("reflexion_data.json", "a", encoding="utf-8") as f: - f.write(f"### {label} ###\n") - json.dump(data, f, ensure_ascii=False, indent=4) - f.write("\n\n") - # 正确地在协程内等待后台线程执行,避免未等待的协程警告 - await asyncio.to_thread(_write) - - -async def self_reflexion(host_id: uuid.UUID) -> str: - """ - 自我反思引擎,执行反思流程。 - - Args: - host_id: 主机ID - - Returns: - 反思结果描述字符串 - """ - if not REFLEXION_ENABLED: - return "未开启反思..." - print(f"====== 自我反思流程开始 ======\n") - reflexion_data = await get_reflexion_data(host_id) - if not reflexion_data: - print(f"====== 自我反思流程结束 ======\n") - return "无反思数据,结束反思" - print(f"反思数据获取成功,共 {len(reflexion_data)} 条") - - conflict_data = await run_conflict(reflexion_data) - if not conflict_data: - print(f"====== 自我反思流程结束 ======\n") - return "无冲突,无需反思" - print(f"冲突记忆类型: {type(conflict_data)}") - await _append_json("conflict", conflict_data) - - solved_data = await run_reflexion(conflict_data) - if not solved_data: - print(f"====== 自我反思流程结束 ======\n") - return "反思失败,未解决冲突" - print(f"解决冲突后的记忆类型: {type(solved_data)}") - await _append_json("solved_data", solved_data) - - result = await update_memory(solved_data, host_id) - print(f"更新记忆库结果: {result}") - print(f"====== 自我反思流程结束 ======\n") - return result - - -if __name__ == "__main__": - import asyncio - # host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333") - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) diff --git a/api/app/core/rag/llm/cv_model.py b/api/app/core/rag/llm/cv_model.py index 24d4a35b..4207304a 100644 --- a/api/app/core/rag/llm/cv_model.py +++ b/api/app/core/rag/llm/cv_model.py @@ -243,6 +243,33 @@ class QWenCV(GptV4): tmp_path = tmp.name video_path = f"file://{tmp_path}" + prompt_ch = """ + 你是一名专业的视频转录助手,能够将视频文件的内容转写为文本,并**精确标记每句话或每个段落对应的时间戳**(开始时间-结束时间)。\n + **任务要求**: + 1.输入是MP4等视频文件,解析带时间戳的文本。 + 2.时间戳格式为 `[HH:MM:SS.mmm]`(毫秒可选),例如 `[00:01:23.456]`。 + 3.时间戳需尽可能贴近实际视频的起止时间,误差不超过1秒。 + 4.如果无法确定具体时间,请根据上下文合理估算。 + 5.最后总结:这段视频的内容是什么?,并用恰当的句子总结这个视频。 + + **示例输出**: + [00:00:00.000] 今天天气真好, + [00:00:02.500] 我们一起去公园散步吧。 + [00:00:05.800] 公园里的花开得非常漂亮。 + 这段视频的内容是关于如何在CREAMS系统中进行楼宇管理集合的编辑或删除操作。视频演示了 ...""" + prompt_en = """ + You are a professional video transcription assistant, capable of transcribing the content of video files into text and **precisely marking the timestamp (start time-end time) corresponding to each sentence or paragraph**. + **Task requirements**: + 1. Input is MP4 or other video files, and parse the text with timestamps. + 2. The timestamp format is `[HH:MM:SS.mmm]` (milliseconds are optional), for example, `[00:01:23.456]`. + 3. The timestamp should be as close as possible to the actual start and end time of the video, with an error not exceeding 1 second. + 4. If the specific time cannot be determined, please make a reasonable estimation based on the context. + 5. Final summary: What is the content of this video? Summarize this video in an appropriate sentence. + + **Example output**: + [00:00:00.000] The weather is really nice today, [00:00:02.500] let's go for a walk in the park together. + [00:00:05.800] The flowers in the park are blooming beautifully. + The content of this video is about how to edit or delete building management collections in the CREAMS system. The video demonstrates ..""" messages = [ { "role": "user", @@ -252,7 +279,7 @@ class QWenCV(GptV4): "fps": 2, }, { - "text": "视频的内容是什么?,并且,请用恰当的句子总结这个视频。" if self.lang.lower() == "chinese" else "What is the content of the video? And please summarize this video in proper sentences.", + "text": prompt_ch if self.lang.lower() == "chinese" else prompt_en, }, ], } diff --git a/api/app/core/rag/llm/sequence2txt_model.py b/api/app/core/rag/llm/sequence2txt_model.py index be4d3649..468dda55 100644 --- a/api/app/core/rag/llm/sequence2txt_model.py +++ b/api/app/core/rag/llm/sequence2txt_model.py @@ -60,6 +60,34 @@ class QWenSeq2txt(Base): from dashscope import MultiModalConversation audio_path = f"file://{audio_path}" + prompt_ch = """ + 你是一名专业的音频转录助手,能够将MP3音频文件的内容转写为文本,并**精确标记每句话或每个段落对应的时间戳**(开始时间-结束时间)。\n + **任务要求**: + 1.输入是MP3,解析带时间戳的文本。 + 2.时间戳格式为 `[HH:MM:SS.mmm]`(毫秒可选),例如 `[00:01:23.456]`。 + 3.时间戳需尽可能贴近实际语音的起止时间,误差不超过1秒。 + 4.如果无法确定具体时间,请根据上下文合理估算。 + 5.最后总结:这段音频在说什么? + + **示例输出**: + [00:00:00.000] 今天天气真好, + [00:00:02.500] 我们一起去公园散步吧。 + [00:00:05.800] 公园里的花开得非常漂亮。 + 这段音频讲述的是一个关于**“吃水不忘挖井人”**的感人故事,主 ...""" + prompt_en = """ + You are a professional audio transcription assistant, capable of transcribing the content of MP3 audio files into text and **precisely marking the timestamps (start time - end time) corresponding to each sentence or paragraph**. + **Task requirements**: + 1. Input is MP3, parse text with timestamps. + 2. The timestamp format is `[HH:MM:SS.mmm]` (milliseconds are optional), for example, `[00:01:23.456]`. + 3. The timestamp should be as close as possible to the actual start and end time of the voice, with an error not exceeding 1 second. + 4. If a specific time cannot be determined, please make a reasonable estimation based on the context. + 5. Final summary: What is this audio talking about? + + **Example Output**: + [00:00:00.000] The weather is really nice today, + [00:00:02.500] let's go for a walk in the park together. + [00:00:05.800] The flowers in the park are blooming beautifully. + This audio tells a touching story about **"Remembering the one who dug the well when drinking water"** ..""" messages = [ { "role": "user", @@ -68,7 +96,7 @@ class QWenSeq2txt(Base): "audio": audio_path }, { - "text": "这段音频在说什么?" if self.lang.lower() == "chinese" else "What is this audio saying?", + "text": prompt_ch if self.lang.lower() == "chinese" else prompt_en, }, ], } diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index c048f447..ad03fec1 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -8,6 +8,7 @@ import logging import uuid from typing import Any +from langchain_core.runnables import RunnableConfig from langgraph.graph.state import CompiledStateGraph from app.core.workflow.graph_builder import GraphBuilder @@ -53,11 +54,11 @@ class WorkflowExecutor: self.edges = workflow_config.get("edges", []) self.execution_config = workflow_config.get("execution_config", {}) - self.checkpoint_config = { - "configurable": { + self.checkpoint_config = RunnableConfig( + configurable={ "thread_id": uuid.uuid4(), } - } + ) def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: """准备初始状态(注入系统变量和会话变量) @@ -214,13 +215,13 @@ class WorkflowExecutor: return { "status": "completed", "output": final_output, + "variables": result.get("variables", {}), "node_outputs": node_outputs, "messages": result.get("messages", []), "conversation_id": conversation_id, "elapsed_time": elapsed_time, "token_usage": token_usage, "error": result.get("error"), - "variables": result.get("variables", {}), } def build_graph(self, stream=False) -> CompiledStateGraph: @@ -326,11 +327,10 @@ class WorkflowExecutor: } # 1. 构建图 - graph = self.build_graph(True) + graph = self.build_graph(stream=True) # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) - # 3. Execute workflow try: chunk_count = 0 @@ -346,14 +346,16 @@ class WorkflowExecutor: mode, data = event else: # Unexpected format, log and skip - logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}") + logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}" + f"- execution_id: {self.execution_id}") continue if mode == "custom": # Handle custom streaming events (chunks from nodes via stream writer) chunk_count += 1 event_type = data.get("type", "node_chunk") # "message" or "node_chunk" - logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}") + logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}" + f"- execution_id: {self.execution_id}") yield { "event": event_type, # "message" or "node_chunk" "data": { @@ -380,7 +382,8 @@ class WorkflowExecutor: variables_sys = variables.get("sys", {}) conversation_id = input_data.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] Node starts execution: {node_name}") + logger.info(f"[NODE-START] Node starts execution: {node_name} " + f"- execution_id: {self.execution_id}") yield { "event": "node_start", @@ -399,7 +402,8 @@ class WorkflowExecutor: variables_sys = variables.get("sys", {}) conversation_id = input_data.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] Node execution completed: {node_name}") + logger.info(f"[NODE-END] Node execution completed: {node_name} " + f"- execution_id: {self.execution_id}") yield { "event": "node_end", @@ -407,13 +411,15 @@ class WorkflowExecutor: "node_id": node_name, "conversation_id": conversation_id, "execution_id": execution_id, - "timestamp": data.get("timestamp") + "timestamp": data.get("timestamp"), + "state": result.get("node_outputs", {}).get(node_name), } } elif mode == "updates": # Handle state updates - store final state - logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}") + logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} " + f"- execution_id: {self.execution_id}") # 计算耗时 end_time = datetime.datetime.now() @@ -421,7 +427,7 @@ class WorkflowExecutor: result = graph.get_state(self.checkpoint_config).values logger.info( f"Workflow execution completed (streaming), " - f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s" + f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}" ) # 发送 workflow_end 事件 @@ -449,7 +455,8 @@ class WorkflowExecutor: } } - def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None: + @staticmethod + def _extract_final_output(node_outputs: dict[str, Any]) -> str | None: """从节点输出中提取最终输出 优先级: @@ -473,7 +480,8 @@ class WorkflowExecutor: return None - def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None: + @staticmethod + def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None: """聚合所有节点的 token 使用情况 Args: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 72fd0bb5..b31213d8 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -25,7 +25,7 @@ class WorkflowState(TypedDict): The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. """ # List of messages (append mode) - messages: list[dict[str, str]] + messages: Annotated[list[dict[str, str]], lambda x, y: y] # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index 4ae8e118..66c3a700 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -21,6 +21,7 @@ class IterationRuntime: optional parallel execution, flattening of output, and loop control via the workflow state. """ + def __init__( self, graph: CompiledStateGraph, @@ -87,6 +88,7 @@ class IterationRuntime: self.result.append(output) if not result["looping"]: self.looping = False + return result def _create_iteration_tasks(self, array_obj, idx): """ @@ -124,7 +126,7 @@ class IterationRuntime: array_obj = VariablePool(self.state).get(input_expression) if not isinstance(array_obj, list): raise RuntimeError("Cannot iterate over a non-list variable") - + child_state = [] idx = 0 if self.typed_config.parallel: # Execute iterations in parallel batches @@ -132,15 +134,14 @@ class IterationRuntime: tasks = self._create_iteration_tasks(array_obj, idx) logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}") idx += self.typed_config.parallel_count - await asyncio.gather(*tasks) - logger.info(f"Iteration node {self.node_id}: execution completed") - return self.result + child_state.extend(await asyncio.gather(*tasks)) else: # Execute iterations sequentially while idx < len(array_obj) and self.looping: logger.info(f"Iteration node {self.node_id}: running") item = array_obj[idx] result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + child_state.append(result) output = VariablePool(result).get(self.output_value) if isinstance(output, list) and self.typed_config.flatten: self.result.extend(output) @@ -150,5 +151,8 @@ class IterationRuntime: self.looping = False idx += 1 - logger.info(f"Iteration node {self.node_id}: execution completed") - return self.result + logger.info(f"Iteration node {self.node_id}: execution completed") + return { + "output": self.result, + "__child_state": child_state + } diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 2e2ab4fb..38d4b21c 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -67,7 +67,9 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) + ) + if variable.input_type == ValueInputType.VARIABLE + else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } self.state["node_outputs"][self.node_id] = { @@ -76,7 +78,9 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) + ) + if variable.input_type == ValueInputType.VARIABLE + else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } loopstate = WorkflowState( @@ -171,10 +175,11 @@ class LoopRuntime: """ loopstate = self._init_loop_state() loop_time = self.typed_config.max_loop + child_state = [] while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0: logger.info(f"loop node {self.node_id}: running") - await self.graph.ainvoke(loopstate) + child_state.append(await self.graph.ainvoke(loopstate)) loop_time -= 1 logger.info(f"loop node {self.node_id}: execution completed") - return loopstate["runtime_vars"][self.node_id] + return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state} diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 221ca079..997135f3 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -10,9 +10,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.db import get_db_read from app.models import knowledge_model, knowledgeshare_model, ModelType -from app.repositories import knowledge_repository +from app.repositories import knowledge_repository, knowledgeshare_repository from app.schemas.chunk_schema import RetrieveType -from app.services import knowledge_service, knowledgeshare_service from app.services.model_service import ModelConfigService logger = logging.getLogger(__name__) @@ -96,7 +95,7 @@ class KnowledgeRetrievalNode(BaseNode): filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) - share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( + share_ids = knowledge_repository.get_chunked_knowledgeids( db=db, filters=filters ) @@ -105,7 +104,7 @@ class KnowledgeRetrievalNode(BaseNode): filters = [ knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids) ] - items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( + items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( db=db, filters=filters ) diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index f65d5879..265724f3 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -66,7 +66,7 @@ class LLMNodeConfig(BaseNodeConfig): ) memory: MemoryWindowSetting = Field( - ..., + default_factory=MemoryWindowSetting, description="对话上下文窗口" ) diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index e25bd35d..a74e0b60 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -85,6 +85,7 @@ class LLMNode(BaseNode): """ # 1. 处理消息格式(优先使用 messages) + self.typed_config = LLMNodeConfig(**self.config) messages_config = self.typed_config.messages if messages_config: @@ -167,7 +168,7 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ - self.typed_config = LLMNodeConfig(**self.config) + # self.typed_config = LLMNodeConfig(**self.config) llm, prompt_or_messages = self._prepare_llm(state, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") @@ -269,12 +270,16 @@ class LLMNode(BaseNode): chunk_count = 0 # 调用 LLM(流式,支持字符串或消息列表) - async for chunk in llm.astream(prompt_or_messages): + last_meta_data = {} + async for chunk in llm.astream(prompt_or_messages, stream_usage=True): # 提取内容 if hasattr(chunk, 'content'): content = chunk.content else: content = str(chunk) + if hasattr(chunk, 'response_metadata'): + if chunk.response_metadata: + last_meta_data = chunk.response_metadata # 只有当内容不为空时才处理 if content: @@ -288,13 +293,10 @@ class LLMNode(BaseNode): logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") # 构建完整的 AIMessage(包含元数据) - if isinstance(last_chunk, AIMessage): - final_message = AIMessage( - content=full_response, - response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} - ) - else: - final_message = AIMessage(content=full_response) + final_message = AIMessage( + content=full_response, + response_metadata=last_meta_data + ) # yield 完成标记 yield {"__final__": True, "result": final_message} diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index f45991cd..81cc6ead 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -27,8 +27,6 @@ from .tool_model import ( ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus ) from .memory_perceptual_model import MemoryPerceptualModel -from .emotion_suggestions_cache_model import EmotionSuggestionsCache -from .implicit_memory_cache_model import ImplicitMemoryCache __all__ = [ "Tenants", @@ -79,6 +77,4 @@ __all__ = [ "AuthType", "ExecutionStatus", "MemoryPerceptualModel", - "EmotionSuggestionsCache", - "ImplicitMemoryCache" ] diff --git a/api/app/models/emotion_suggestions_cache_model.py b/api/app/models/emotion_suggestions_cache_model.py deleted file mode 100644 index 9b32f424..00000000 --- a/api/app/models/emotion_suggestions_cache_model.py +++ /dev/null @@ -1,24 +0,0 @@ -"""情绪建议缓存模型""" - -import uuid -import datetime -from sqlalchemy import Column, String, Text, Integer, DateTime, JSON -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - - -class EmotionSuggestionsCache(Base): - """情绪建议缓存表 - - 用于缓存个性化情绪建议,减少 LLM 调用成本,提升响应速度。 - """ - __tablename__ = "emotion_suggestions_cache" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - end_user_id = Column(String(255), nullable=False, unique=True, index=True, comment="终端用户ID(组ID)") - health_summary = Column(Text, nullable=False, comment="健康状态摘要") - suggestions = Column(JSON, nullable=False, comment="建议列表(JSON格式)") - generated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, comment="生成时间") - expires_at = Column(DateTime, nullable=True, comment="过期时间") - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) diff --git a/api/app/models/implicit_memory_cache_model.py b/api/app/models/implicit_memory_cache_model.py deleted file mode 100644 index 32defbab..00000000 --- a/api/app/models/implicit_memory_cache_model.py +++ /dev/null @@ -1,27 +0,0 @@ -"""隐性记忆缓存模型""" - -import uuid -import datetime -from sqlalchemy import Column, String, Integer, DateTime, JSON -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - - -class ImplicitMemoryCache(Base): - """隐性记忆缓存表 - - 用于缓存用户的完整隐性记忆画像,包括偏好标签、四维画像、兴趣领域和行为习惯。 - 减少 LLM 调用成本,提升响应速度。 - """ - __tablename__ = "implicit_memory_cache" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - end_user_id = Column(String(255), nullable=False, unique=True, index=True, comment="终端用户ID") - preferences = Column(JSON, nullable=False, comment="偏好标签列表(JSON格式)") - portrait = Column(JSON, nullable=False, comment="四维画像对象(JSON格式)") - interest_areas = Column(JSON, nullable=False, comment="兴趣领域分布对象(JSON格式)") - habits = Column(JSON, nullable=False, comment="行为习惯列表(JSON格式)") - generated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, comment="生成时间") - expires_at = Column(DateTime, nullable=True, comment="过期时间") - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) diff --git a/api/app/models/workflow_model.py b/api/app/models/workflow_model.py index d599f717..4f9ffe68 100644 --- a/api/app/models/workflow_model.py +++ b/api/app/models/workflow_model.py @@ -75,6 +75,14 @@ class WorkflowExecution(Base): nullable=False, index=True ) + + release_id = Column( + UUID(as_uuid=True), + ForeignKey("app_releases.id", ondelete="CASCADE"), + nullable=True, + index=True + ) + app_id = Column( UUID(as_uuid=True), ForeignKey("apps.id", ondelete="CASCADE"), diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/data_config_repository.py index 135c0063..d26058b2 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/data_config_repository.py @@ -10,7 +10,7 @@ Classes: import uuid from typing import Dict, List, Optional, Tuple - +from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger from app.models.data_config_model import DataConfig from app.schemas.memory_storage_schema import ( @@ -20,7 +20,7 @@ from app.schemas.memory_storage_schema import ( ConfigUpdateExtracted, ConfigUpdateForget, ) -from sqlalchemy import desc +from sqlalchemy import desc, select from sqlalchemy.orm import Session # 获取数据库专用日志器 @@ -136,72 +136,88 @@ class DataConfigRepository: id: m.id } AS targetNode """ - - # ==================== SQLAlchemy ORM 数据库操作方法 ==================== @staticmethod - def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]: + def update_reflection_config( + db: Session, + config_id: int, + enable_self_reflexion: bool, + iteration_period: str, + reflexion_range: str, + baseline: str, + reflection_model_id: str, + memory_verify: bool, + quality_assessment: bool + ) -> DataConfig: """构建反思配置更新语句(SQLAlchemy text() 命名参数) Args: + quality_assessment: + memory_verify: + reflection_model_id: + baseline: + reflexion_range: + iteration_period: + enable_self_reflexion: + db: database object config_id: 配置ID - **kwargs: 反思配置参数 Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) + Data Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"构建反思配置更新语句: config_id={config_id}") + stmt = select(DataConfig).where(DataConfig.config_id == config_id) + data_config_obj = db.scalars(stmt).first() + if not data_config_obj: + raise BusinessException + data_config_obj.enable_self_reflexion = enable_self_reflexion + data_config_obj.iteration_period = iteration_period + data_config_obj.reflexion_range = reflexion_range + data_config_obj.baseline = baseline + data_config_obj.reflection_model_id = reflection_model_id + data_config_obj.memory_verify = memory_verify + data_config_obj.quality_assessment = quality_assessment - key_where = "config_id = :config_id" - set_fields: List[str] = [] - params: Dict = { - "config_id": config_id, - } - - # 反思配置字段映射 - mapping = { - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - "reflection_model_id": "reflection_model_id", - "memory_verify": "memory_verify", - "quality_assessment": "quality_assessment", - } - - for api_field, db_col in mapping.items(): - if api_field in kwargs and kwargs[api_field] is not None: - set_fields.append(f"{db_col} = :{api_field}") - params[api_field] = kwargs[api_field] - - if not set_fields: - raise ValueError("No fields to update") - - set_fields.append("updated_at = timezone('Asia/Shanghai', now())") - query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" - return query, params + return data_config_obj @staticmethod - def build_select_reflection(config_id: int) -> Tuple[str, Dict]: + def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: + db: database object config_id: 配置ID Returns: Tuple[str, Dict]: (SQL查询字符串, 参数字典) """ db_logger.debug(f"构建反思配置查询语句: config_id={config_id}") + stmt = select(DataConfig).where(DataConfig.config_id == config_id) + data_config = db.scalars(stmt).first() + if not data_config: + raise RuntimeError("reflection config not found") + return data_config + @staticmethod + def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig: + """构建查询所有配置的语句(SQLAlchemy text() 命名参数) + + Args: + db: database object + workspace_id: 工作空间ID + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + """ + db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") + + stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id) + data_config = db.scalars(stmt).first() + if not data_config: + raise RuntimeError("reflection config not found") + return data_config - query = ( - f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, " - f"reflection_model_id, memory_verify, quality_assessment, user_id " - f"FROM {TABLE_NAME} WHERE config_id = :config_id" - ) - params = {"config_id": config_id} - return query, params @staticmethod def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]: diff --git a/api/app/repositories/emotion_suggestions_cache_repository.py b/api/app/repositories/emotion_suggestions_cache_repository.py deleted file mode 100644 index 1c0430d5..00000000 --- a/api/app/repositories/emotion_suggestions_cache_repository.py +++ /dev/null @@ -1,163 +0,0 @@ -"""情绪建议缓存仓储层""" - -from sqlalchemy.orm import Session -from typing import Optional, Dict, Any -import datetime - -from app.models.emotion_suggestions_cache_model import EmotionSuggestionsCache -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class EmotionSuggestionsCacheRepository: - """情绪建议缓存仓储类""" - - def __init__(self, db: Session): - self.db = db - - def get_by_end_user_id(self, end_user_id: str) -> Optional[EmotionSuggestionsCache]: - """根据终端用户ID获取缓存 - - Args: - end_user_id: 终端用户ID(组ID) - - Returns: - 缓存记录,如果不存在返回 None - """ - try: - cache = ( - self.db.query(EmotionSuggestionsCache) - .filter(EmotionSuggestionsCache.end_user_id == end_user_id) - .first() - ) - if cache: - db_logger.info(f"成功获取用户 {end_user_id} 的情绪建议缓存") - else: - db_logger.info(f"用户 {end_user_id} 的情绪建议缓存不存在") - return cache - except Exception as e: - db_logger.error(f"获取用户 {end_user_id} 的情绪建议缓存失败: {str(e)}") - raise - - def create_or_update( - self, - end_user_id: str, - health_summary: str, - suggestions: list, - expires_hours: int = 24 - ) -> EmotionSuggestionsCache: - """创建或更新缓存 - - Args: - end_user_id: 终端用户ID(组ID) - health_summary: 健康状态摘要 - suggestions: 建议列表 - expires_hours: 过期时间(小时),默认24小时 - - Returns: - 缓存记录 - """ - try: - # 查找现有记录 - cache = self.get_by_end_user_id(end_user_id) - - now = datetime.datetime.now() - expires_at = now + datetime.timedelta(hours=expires_hours) - - if cache: - # 更新现有记录 - cache.health_summary = health_summary - cache.suggestions = suggestions - cache.generated_at = now - cache.expires_at = expires_at - cache.updated_at = now - db_logger.info(f"更新用户 {end_user_id} 的情绪建议缓存") - else: - # 创建新记录 - cache = EmotionSuggestionsCache( - end_user_id=end_user_id, - health_summary=health_summary, - suggestions=suggestions, - generated_at=now, - expires_at=expires_at, - created_at=now, - updated_at=now - ) - self.db.add(cache) - db_logger.info(f"创建用户 {end_user_id} 的情绪建议缓存") - - self.db.commit() - self.db.refresh(cache) - return cache - except Exception as e: - self.db.rollback() - db_logger.error(f"创建或更新用户 {end_user_id} 的情绪建议缓存失败: {str(e)}") - raise - - def delete_by_end_user_id(self, end_user_id: str) -> bool: - """删除缓存 - - Args: - end_user_id: 终端用户ID(组ID) - - Returns: - 是否删除成功 - """ - try: - cache = self.get_by_end_user_id(end_user_id) - if cache: - self.db.delete(cache) - self.db.commit() - db_logger.info(f"删除用户 {end_user_id} 的情绪建议缓存") - return True - return False - except Exception as e: - self.db.rollback() - db_logger.error(f"删除用户 {end_user_id} 的情绪建议缓存失败: {str(e)}") - raise - - @staticmethod - def is_expired(cache: EmotionSuggestionsCache) -> bool: - """检查缓存是否过期 - - Args: - cache: 缓存记录 - - Returns: - 是否过期 - """ - if cache.expires_at is None: - return False - return datetime.datetime.now() > cache.expires_at - - -# 便捷函数 -def get_cache_by_end_user_id(db: Session, end_user_id: str) -> Optional[EmotionSuggestionsCache]: - """根据终端用户ID获取缓存""" - repo = EmotionSuggestionsCacheRepository(db) - return repo.get_by_end_user_id(end_user_id) - - -def create_or_update_cache( - db: Session, - end_user_id: str, - health_summary: str, - suggestions: list, - expires_hours: int = 24 -) -> EmotionSuggestionsCache: - """创建或更新缓存""" - repo = EmotionSuggestionsCacheRepository(db) - return repo.create_or_update(end_user_id, health_summary, suggestions, expires_hours) - - -def delete_cache_by_end_user_id(db: Session, end_user_id: str) -> bool: - """删除缓存""" - repo = EmotionSuggestionsCacheRepository(db) - return repo.delete_by_end_user_id(end_user_id) - - -def is_cache_expired(cache: EmotionSuggestionsCache) -> bool: - """检查缓存是否过期""" - return EmotionSuggestionsCacheRepository.is_expired(cache) diff --git a/api/app/repositories/implicit_memory_cache_repository.py b/api/app/repositories/implicit_memory_cache_repository.py deleted file mode 100644 index 65356980..00000000 --- a/api/app/repositories/implicit_memory_cache_repository.py +++ /dev/null @@ -1,175 +0,0 @@ -"""隐性记忆缓存仓储层""" - -from sqlalchemy.orm import Session -from typing import Optional, Dict, Any -import datetime - -from app.models.implicit_memory_cache_model import ImplicitMemoryCache -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class ImplicitMemoryCacheRepository: - """隐性记忆缓存仓储类""" - - def __init__(self, db: Session): - self.db = db - - def get_by_end_user_id(self, end_user_id: str) -> Optional[ImplicitMemoryCache]: - """根据终端用户ID获取缓存 - - Args: - end_user_id: 终端用户ID - - Returns: - 缓存记录,如果不存在返回 None - """ - try: - cache = ( - self.db.query(ImplicitMemoryCache) - .filter(ImplicitMemoryCache.end_user_id == end_user_id) - .first() - ) - if cache: - db_logger.info(f"成功获取用户 {end_user_id} 的隐性记忆缓存") - else: - db_logger.info(f"用户 {end_user_id} 的隐性记忆缓存不存在") - return cache - except Exception as e: - db_logger.error(f"获取用户 {end_user_id} 的隐性记忆缓存失败: {str(e)}") - raise - - def create_or_update( - self, - end_user_id: str, - preferences: list, - portrait: dict, - interest_areas: dict, - habits: list, - expires_hours: int = 168 # 默认7天 - ) -> ImplicitMemoryCache: - """创建或更新缓存 - - Args: - end_user_id: 终端用户ID - preferences: 偏好标签列表 - portrait: 四维画像对象 - interest_areas: 兴趣领域分布对象 - habits: 行为习惯列表 - expires_hours: 过期时间(小时),默认168小时(7天) - - Returns: - 缓存记录 - """ - try: - # 查找现有记录 - cache = self.get_by_end_user_id(end_user_id) - - now = datetime.datetime.now() - expires_at = now + datetime.timedelta(hours=expires_hours) - - if cache: - # 更新现有记录 - cache.preferences = preferences - cache.portrait = portrait - cache.interest_areas = interest_areas - cache.habits = habits - cache.generated_at = now - cache.expires_at = expires_at - cache.updated_at = now - db_logger.info(f"更新用户 {end_user_id} 的隐性记忆缓存") - else: - # 创建新记录 - cache = ImplicitMemoryCache( - end_user_id=end_user_id, - preferences=preferences, - portrait=portrait, - interest_areas=interest_areas, - habits=habits, - generated_at=now, - expires_at=expires_at, - created_at=now, - updated_at=now - ) - self.db.add(cache) - db_logger.info(f"创建用户 {end_user_id} 的隐性记忆缓存") - - self.db.commit() - self.db.refresh(cache) - return cache - except Exception as e: - self.db.rollback() - db_logger.error(f"创建或更新用户 {end_user_id} 的隐性记忆缓存失败: {str(e)}") - raise - - def delete_by_end_user_id(self, end_user_id: str) -> bool: - """删除缓存 - - Args: - end_user_id: 终端用户ID - - Returns: - 是否删除成功 - """ - try: - cache = self.get_by_end_user_id(end_user_id) - if cache: - self.db.delete(cache) - self.db.commit() - db_logger.info(f"删除用户 {end_user_id} 的隐性记忆缓存") - return True - return False - except Exception as e: - self.db.rollback() - db_logger.error(f"删除用户 {end_user_id} 的隐性记忆缓存失败: {str(e)}") - raise - - @staticmethod - def is_expired(cache: ImplicitMemoryCache) -> bool: - """检查缓存是否过期 - - Args: - cache: 缓存记录 - - Returns: - 是否过期 - """ - if cache.expires_at is None: - return False - return datetime.datetime.now() > cache.expires_at - - -# 便捷函数 -def get_cache_by_end_user_id(db: Session, end_user_id: str) -> Optional[ImplicitMemoryCache]: - """根据终端用户ID获取缓存""" - repo = ImplicitMemoryCacheRepository(db) - return repo.get_by_end_user_id(end_user_id) - - -def create_or_update_cache( - db: Session, - end_user_id: str, - preferences: list, - portrait: dict, - interest_areas: dict, - habits: list, - expires_hours: int = 168 -) -> ImplicitMemoryCache: - """创建或更新缓存""" - repo = ImplicitMemoryCacheRepository(db) - return repo.create_or_update( - end_user_id, preferences, portrait, interest_areas, habits, expires_hours - ) - - -def delete_cache_by_end_user_id(db: Session, end_user_id: str) -> bool: - """删除缓存""" - repo = ImplicitMemoryCacheRepository(db) - return repo.delete_by_end_user_id(end_user_id) - - -def is_cache_expired(cache: ImplicitMemoryCache) -> bool: - """检查缓存是否过期""" - return ImplicitMemoryCacheRepository.is_expired(cache) diff --git a/api/app/repositories/memory_increment_repository.py b/api/app/repositories/memory_increment_repository.py index 37396fbd..f3a56622 100644 --- a/api/app/repositories/memory_increment_repository.py +++ b/api/app/repositories/memory_increment_repository.py @@ -25,7 +25,7 @@ class MemoryIncrementRepository: MemoryIncrement, func.row_number().over( partition_by=func.date(MemoryIncrement.created_at), # 按日期分区 - order_by=MemoryIncrement.created_at.desc() # 按时间戳升序排序 + order_by=MemoryIncrement.created_at.desc() # 按时间戳降序排序,取每天最新的 ).label('row_num') ) .filter(MemoryIncrement.workspace_id == workspace_id) @@ -34,14 +34,24 @@ class MemoryIncrementRepository: memory_increment_alias = aliased(MemoryIncrement, subquery) - memory_increments = ( + # 先取最近的limit条记录的子查询 + recent_records_subquery = ( self.db.query(memory_increment_alias) .filter(subquery.c.row_num == 1) # 只取每个日期的第一条(最新的) - .order_by(memory_increment_alias.created_at.asc()) # 按时间戳降序排序 + .order_by(memory_increment_alias.created_at.desc()) # 按时间戳降序排序,取最近的 .limit(limit) + .subquery() + ) + + # 在外层按升序排列(从旧到新) + recent_alias = aliased(MemoryIncrement, recent_records_subquery) + memory_increments = ( + self.db.query(recent_alias) + .order_by(recent_alias.created_at.asc()) # 按时间戳升序排序 .all() ) - db_logger.info(f"成功查询工作空间 {workspace_id} 下的内存增量") + + db_logger.info(f"成功查询工作空间 {workspace_id} 下的内存增量,返回最近 {len(memory_increments)} 条记录") return memory_increments except Exception as e: db_logger.error(f"查询工作空间 {workspace_id} 下内存增量时出错: {str(e)}") diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index c91c2e80..cd3cbed7 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -837,12 +837,14 @@ neo4j_query_part = """ WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN + elementId(m) as id, m.name as entity1_name, m.description as description, m.statement_id as statement_id, m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + elementId(rel) as rel_id, rel.predicate as predicate, rel.statement as relationship, rel.statement_id as relationship_statement_id, @@ -855,12 +857,14 @@ neo4j_query_all = """ WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN + elementId(m) as id, m.name as entity1_name, m.description as description, m.statement_id as statement_id, m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + elementId(rel) as rel_id, rel.predicate as predicate, rel.statement as relationship, rel.statement_id as relationship_statement_id, diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 80756793..0b6a27c6 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -1,29 +1,30 @@ -from typing import Any, Dict, List, Optional import asyncio import logging +from typing import Any, Dict, List, Optional + +from app.repositories.neo4j.cypher_queries import ( + CHUNK_EMBEDDING_SEARCH, + ENTITY_EMBEDDING_SEARCH, + MEMORY_SUMMARY_EMBEDDING_SEARCH, + SEARCH_CHUNK_BY_CHUNK_ID, + SEARCH_CHUNKS_BY_CONTENT, + SEARCH_DIALOGUE_BY_DIALOG_ID, + SEARCH_ENTITIES_BY_NAME, + SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, + SEARCH_STATEMENTS_BY_CREATED_AT, + SEARCH_STATEMENTS_BY_KEYWORD, + SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, + SEARCH_STATEMENTS_BY_TEMPORAL, + SEARCH_STATEMENTS_BY_VALID_AT, + SEARCH_STATEMENTS_G_CREATED_AT, + SEARCH_STATEMENTS_G_VALID_AT, + SEARCH_STATEMENTS_L_CREATED_AT, + SEARCH_STATEMENTS_L_VALID_AT, + STATEMENT_EMBEDDING_SEARCH, +) # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.cypher_queries import ( - SEARCH_STATEMENTS_BY_KEYWORD, - SEARCH_ENTITIES_BY_NAME, - SEARCH_CHUNKS_BY_CONTENT, - STATEMENT_EMBEDDING_SEARCH, - CHUNK_EMBEDDING_SEARCH, - ENTITY_EMBEDDING_SEARCH, - SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, - MEMORY_SUMMARY_EMBEDDING_SEARCH, - SEARCH_STATEMENTS_BY_TEMPORAL, - SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, - SEARCH_DIALOGUE_BY_DIALOG_ID, - SEARCH_CHUNK_BY_CHUNK_ID, - SEARCH_STATEMENTS_BY_CREATED_AT, - SEARCH_STATEMENTS_BY_VALID_AT, - SEARCH_STATEMENTS_G_CREATED_AT, - SEARCH_STATEMENTS_L_CREATED_AT, - SEARCH_STATEMENTS_G_VALID_AT, - SEARCH_STATEMENTS_L_VALID_AT, -) logger = logging.getLogger(__name__) @@ -55,8 +56,12 @@ async def _update_activation_values_batch( return [] # 延迟导入以避免循环依赖 - from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager - from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator + from app.core.memory.storage_services.forgetting_engine.access_history_manager import ( + AccessHistoryManager, + ) + from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( + ACTRCalculator, + ) # 创建计算器和管理器实例 actr_calculator = ACTRCalculator() @@ -292,6 +297,13 @@ async def search_graph( else: results[key] = result + # Deduplicate results before updating activation values + # This prevents duplicates from propagating through the pipeline + from app.core.memory.src.search import _deduplicate_results + for key in results: + if isinstance(results[key], list): + results[key] = _deduplicate_results(results[key]) + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) results = await _update_search_results_activation( connector=connector, @@ -397,6 +409,13 @@ async def search_graph_by_embedding( else: results[key] = result + # Deduplicate results before updating activation values + # This prevents duplicates from propagating through the pipeline + from app.core.memory.src.search import _deduplicate_results + for key in results: + if isinstance(results[key], list): + results[key] = _deduplicate_results(results[key]) + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) update_start = time.time() results = await _update_search_results_activation( diff --git a/api/app/repositories/neo4j/neo4j_update.py b/api/app/repositories/neo4j/neo4j_update.py index 73b44396..753ae256 100644 --- a/api/app/repositories/neo4j/neo4j_update.py +++ b/api/app/repositories/neo4j/neo4j_update.py @@ -11,22 +11,28 @@ async def update_neo4j_data(neo4j_dict_data, update_databases): update_databases: update """ try: - # 构建WHERE条件 + # 构建WHERE条件 - 只使用elementId where_conditions = [] params = {} - for key, value in neo4j_dict_data.items(): - if value is not None: - param_name = f"param_{key}" - where_conditions.append(f"e.{key} = ${param_name}") - params[param_name] = value + # 优先使用id作为elementId进行查询 + if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None: + where_conditions.append(f"elementId(e) = $param_id") + params['param_id'] = neo4j_dict_data['id'] + else: + # 如果没有id,使用其他字段作为条件 + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"e.{key} = ${param_name}") + params[param_name] = value where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" - # 构建SET条件 + # 构建SET条件 - 排除id字段 set_conditions = [] for key, value in update_databases.items(): - if value is not None: + if value is not None and key != 'id': # 不更新id字段 param_name = f"update_{key}" set_conditions.append(f"e.{key} = ${param_name}") params[param_name] = value @@ -76,22 +82,28 @@ async def update_neo4j_data_edge(neo4j_dict_data, update_databases): update_databases: update """ try: - # 构建WHERE条件 + # 构建WHERE条件 - 只使用elementId where_conditions = [] params = {} - for key, value in neo4j_dict_data.items(): - if value is not None: - param_name = f"param_{key}" - where_conditions.append(f"r.{key} = ${param_name}") - params[param_name] = value + # 优先使用id作为elementId进行查询 + if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None: + where_conditions.append(f"elementId(r) = $param_id") + params['param_id'] = neo4j_dict_data['id'] + else: + # 如果没有id,使用其他字段作为条件 + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"r.{key} = ${param_name}") + params[param_name] = value where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" - # 构建SET条件 + # 构建SET条件 - 排除id字段 set_conditions = [] for key, value in update_databases.items(): - if value is not None: + if value is not None and key != 'id': # 不更新id字段 param_name = f"update_{key}" set_conditions.append(f"r.{key} = ${param_name}") params[param_name] = value @@ -242,7 +254,16 @@ async def neo4j_data(solved_data): if key=='expired_at': updat_expired_at[key] = values[1] - elif key == 'statement_id': + elif key == 'id': + ori_edge[key] = values + updata_edge[key] = values + + ori_entity[key] = values + updata_entity[key] = values + + ori_expired_at[key] = values + elif key == 'rel_id': + key='id' ori_edge[key] = values updata_edge[key] = values diff --git a/api/app/schemas/emotion_schema.py b/api/app/schemas/emotion_schema.py index 37e9a2e3..5175fed1 100644 --- a/api/app/schemas/emotion_schema.py +++ b/api/app/schemas/emotion_schema.py @@ -34,5 +34,4 @@ class EmotionSuggestionsRequest(BaseModel): class EmotionGenerateSuggestionsRequest(BaseModel): """生成个性化情绪建议请求""" - group_id: str = Field(..., description="组ID") - config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)") + end_user_id: str = Field(..., description="终端用户ID") diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index ecb1570f..d17a9f2c 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -35,10 +35,10 @@ class BaseDataSchema(BaseModel): expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") description: Optional[str] = Field(None, description="The description of the data entry.") - # 新增字段以匹配实际输入数据 - entity1_name: str = Field(..., description="The first entity name.") + # 新增字段以匹配实际输入数据 - 改为可选以支持resolved_memory场景 + entity1_name: Optional[str] = Field(None, description="The first entity name.") entity2_name: Optional[str] = Field(None, description="The second entity name.") - statement_id: str = Field(..., description="The statement identifier.") + statement_id: Optional[str] = Field(None, description="The statement identifier.") # 新增字段 - 设为可选以保持向后兼容性 predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.") relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.") @@ -108,13 +108,13 @@ class ChangeRecordSchema(BaseModel): """Schema for individual change records 字段值格式说明: - - id 和 statement_id: 字符串或 None + - id: 字符串,表示修改字段对应的记录ID - 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构 - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 """ field: List[Dict[str, Any]] = Field( ..., - description="List of field changes. First item: {id: value or None}, second: {statement_id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" + description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" ) class ResolvedSchema(BaseModel): diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index bc2d6ca3..c0a66e03 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -527,6 +527,7 @@ class AppChatService: conversation_id: uuid.UUID, config: WorkflowConfig, app_id: uuid.UUID, + release_id: uuid.UUID, workspace_id: uuid.UUID, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, @@ -549,6 +550,7 @@ class AppChatService: payload=payload, config=config, workspace_id=workspace_id, + release_id=release_id, ) async def workflow_chat_stream( @@ -557,6 +559,7 @@ class AppChatService: conversation_id: uuid.UUID, config: WorkflowConfig, app_id: uuid.UUID, + release_id: uuid.UUID, workspace_id: uuid.UUID, user_id: str = None, variables: Optional[Dict[str, Any]] = None, @@ -565,7 +568,7 @@ class AppChatService: storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[dict, None]: """聊天(流式)""" workflow_service = WorkflowService(self.db) payload = DraftRunRequest( @@ -580,6 +583,7 @@ class AppChatService: payload=payload, config=config, workspace_id=workspace_id, + release_id=release_id ): yield event diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 6d5204f8..2ac9ac05 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -129,7 +129,7 @@ class AppService: Raises: ResourceNotFoundException: 当应用不存在时 """ - app = get_apps_by_id(self.db,app_id) + app = get_apps_by_id(self.db, app_id) if not app: logger.warning("应用不存在", extra={"app_id": str(app_id)}) raise ResourceNotFoundException("应用", str(app_id)) @@ -227,7 +227,6 @@ class AppService: if not model_api_key: raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id)) - # 3. 检查子 Agent 配置 if not multi_agent_config.sub_agents or len(multi_agent_config.sub_agents) == 0: raise BusinessException( @@ -281,10 +280,10 @@ class AppService: ) def _create_agent_config( - self, - app_id: uuid.UUID, - config_data: app_schema.AgentConfigCreate, - now: datetime.datetime + self, + app_id: uuid.UUID, + config_data: app_schema.AgentConfigCreate, + now: datetime.datetime ) -> None: """创建 Agent 配置(内部方法) @@ -313,10 +312,10 @@ class AppService: logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)}) def _create_multi_agent_config( - self, - app_id: uuid.UUID, - config_data: Dict[str, Any], - now: datetime.datetime + self, + app_id: uuid.UUID, + config_data: Dict[str, Any], + now: datetime.datetime ) -> None: """创建多 Agent 配置(内部方法) @@ -411,9 +410,9 @@ class AppService: return 1 if max_ver is None else int(max_ver) + 1 def _convert_to_schema( - self, - app: App, - current_workspace_id: uuid.UUID + self, + app: App, + current_workspace_id: uuid.UUID ) -> app_schema.App: """将 App 模型转换为 Schema,并设置 is_shared 字段 @@ -447,9 +446,9 @@ class AppService: # ==================== 应用管理 ==================== def get_app( - self, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + app_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> App: """获取应用详情 @@ -469,11 +468,11 @@ class AppService: return app def create_app( - self, - *, - user_id: uuid.UUID, - workspace_id: uuid.UUID, - data: app_schema.AppCreate + self, + *, + user_id: uuid.UUID, + workspace_id: uuid.UUID, + data: app_schema.AppCreate ) -> App: """创建应用 @@ -535,11 +534,11 @@ class AppService: raise BusinessException(f"应用创建失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) def update_app( - self, - *, - app_id: uuid.UUID, - data: app_schema.AppUpdate, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + data: app_schema.AppUpdate, + workspace_id: Optional[uuid.UUID] = None ) -> App: """更新应用基本信息 @@ -578,10 +577,10 @@ class AppService: return app def delete_app( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> None: """删除应用 @@ -612,12 +611,12 @@ class AppService: ) def copy_app( - self, - *, - app_id: uuid.UUID, - user_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None, - new_name: Optional[str] = None + self, + *, + app_id: uuid.UUID, + user_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None, + new_name: Optional[str] = None ) -> App: """复制应用(包括基础信息和配置) @@ -716,16 +715,16 @@ class AppService: raise BusinessException(f"应用复制失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) def list_apps( - self, - *, - workspace_id: uuid.UUID, - type: Optional[str] = None, - visibility: Optional[str] = None, - status: Optional[str] = None, - search: Optional[str] = None, - include_shared: bool = True, - page: int = 1, - pagesize: int = 10, + self, + *, + workspace_id: uuid.UUID, + type: Optional[str] = None, + visibility: Optional[str] = None, + status: Optional[str] = None, + search: Optional[str] = None, + include_shared: bool = True, + page: int = 1, + pagesize: int = 10, ) -> Tuple[List[App], int]: """列出工作空间中的应用(分页) @@ -759,8 +758,7 @@ class AppService: ) # 构建查询条件 - filters = [] - filters.append(App.is_active == True) + filters = [App.is_active == True] if type: filters.append(App.type == type) if visibility: @@ -813,9 +811,9 @@ class AppService: return items, int(total) def get_apps_by_ids( - self, - app_ids: List[str], - workspace_id: uuid.UUID + self, + app_ids: List[str], + workspace_id: uuid.UUID ) -> List[App]: """根据ID列表获取应用 @@ -846,11 +844,11 @@ class AppService: # ==================== Agent 配置管理 ==================== def update_agent_config( - self, - *, - app_id: uuid.UUID, - data: app_schema.AgentConfigUpdate, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + data: app_schema.AgentConfigUpdate, + workspace_id: Optional[uuid.UUID] = None ) -> AgentConfig: """更新 Agent 配置 @@ -875,7 +873,8 @@ class AppService: self._validate_workspace_access(app, workspace_id) - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active==True).order_by(AgentConfig.updated_at.desc()) + stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by( + AgentConfig.updated_at.desc()) agent_cfg: Optional[AgentConfig] = self.db.scalars(stmt).first() now = datetime.datetime.now() @@ -918,10 +917,10 @@ class AppService: return agent_cfg def get_agent_config( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> AgentConfig: """获取 Agent 配置 @@ -948,7 +947,12 @@ class AppService: # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc()) + stmt = select(AgentConfig).where( + AgentConfig.app_id == app_id, + AgentConfig.is_active.is_(True) + ).order_by( + AgentConfig.updated_at.desc() + ) config = self.db.scalars(stmt).first() if config: @@ -1166,13 +1170,13 @@ class AppService: # ==================== 应用发布管理 ==================== def publish( - self, - *, - app_id: uuid.UUID, - publisher_id: uuid.UUID, - version_name: str, - workspace_id: Optional[uuid.UUID] = None, - release_notes: Optional[str] = None + self, + *, + app_id: uuid.UUID, + publisher_id: uuid.UUID, + version_name: str, + workspace_id: Optional[uuid.UUID] = None, + release_notes: Optional[str] = None ) -> AppRelease: """发布应用(创建不可变快照) @@ -1200,7 +1204,8 @@ class AppService: default_model_config_id = None if app.type == AppType.AGENT: - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc()) + stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by( + AgentConfig.updated_at.desc()) agent_cfg = self.db.scalars(stmt).first() if not agent_cfg: raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING) @@ -1236,8 +1241,7 @@ class AppService: default_model_config_id = multi_agent_cfg.default_model_config_id # 4. 构建配置快照 - - + config = { "model_parameters": model_parameters_to_dict(multi_agent_cfg.model_parameters), "master_agent_id": str(multi_agent_cfg.master_agent_id), @@ -1264,6 +1268,7 @@ class AppService: raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING) config = { + "id": str(workflow_cfg.id), "nodes": workflow_cfg.nodes, "edges": workflow_cfg.edges, "variables": workflow_cfg.variables, @@ -1285,7 +1290,7 @@ class AppService: id=uuid.uuid4(), app_id=app_id, version=version, - version_name = version_name, + version_name=version_name, release_notes=release_notes, name=app.name, description=app.description, @@ -1319,10 +1324,10 @@ class AppService: return release def get_current_release( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> Optional[AppRelease]: """获取当前发布版本 @@ -1349,10 +1354,10 @@ class AppService: return self.db.get(AppRelease, app.current_release_id) def list_releases( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> List[AppRelease]: """列出应用的所有发布版本(倒序) @@ -1381,11 +1386,11 @@ class AppService: return list(self.db.scalars(stmt).all()) def rollback( - self, - *, - app_id: uuid.UUID, - version: int, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + version: int, + workspace_id: Optional[uuid.UUID] = None ) -> AppRelease: """回滚到指定版本 @@ -1434,12 +1439,12 @@ class AppService: # ==================== 应用分享功能 ==================== def share_app( - self, - *, - app_id: uuid.UUID, - target_workspace_ids: List[uuid.UUID], - user_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + target_workspace_ids: List[uuid.UUID], + user_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> AppShare: """分享应用到其他工作空间 @@ -1457,7 +1462,6 @@ class AppService: BusinessException: 当应用不在指定工作空间或目标工作空间无效时 """ - logger.info( "分享应用", extra={ @@ -1536,11 +1540,11 @@ class AppService: return shares def unshare_app( - self, - *, - app_id: uuid.UUID, - target_workspace_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + target_workspace_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> None: """取消应用分享 @@ -1594,10 +1598,10 @@ class AppService: ) def list_app_shares( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> List[AppShare]: """列出应用的所有分享记录 @@ -1637,14 +1641,14 @@ class AppService: # ==================== 试运行功能 ==================== async def draft_run( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + message: str, + conversation_id: Optional[str] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + workspace_id: Optional[uuid.UUID] = None ) -> Dict[str, Any]: """试运行 Agent(使用当前草稿配置) @@ -1736,14 +1740,14 @@ class AppService: return result async def draft_run_stream( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None + self, + *, + app_id: uuid.UUID, + message: str, + conversation_id: Optional[str] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + workspace_id: Optional[uuid.UUID] = None ): """试运行 Agent(流式返回) @@ -1794,30 +1798,30 @@ class AppService: # 4. 调用流式试运行服务 draft_service = DraftRunService(self.db) async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables + agent_config=agent_cfg, + model_config=model_config, + message=message, + workspace_id=workspace_id, + conversation_id=conversation_id, + user_id=user_id, + variables=variables ): yield event # ==================== 多模型对比试运行 ==================== async def draft_run_compare( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 + self, + *, + app_id: uuid.UUID, + message: str, + models: List[app_schema.ModelCompareItem], + conversation_id: Optional[str] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + workspace_id: Optional[uuid.UUID] = None, + parallel: bool = True, + timeout: int = 60 ) -> Dict[str, Any]: """多模型对比试运行 @@ -1907,17 +1911,17 @@ class AppService: return result async def draft_run_compare_stream( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 + self, + *, + app_id: uuid.UUID, + message: str, + models: List[app_schema.ModelCompareItem], + conversation_id: Optional[str] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + workspace_id: Optional[uuid.UUID] = None, + parallel: bool = True, + timeout: int = 60 ): """多模型对比试运行(流式返回) @@ -1982,15 +1986,15 @@ class AppService: # 4. 调用 DraftRunService 的流式对比方法 draft_service = DraftRunService(self.db) async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout + agent_config=agent_cfg, + models=model_configs, + message=message, + workspace_id=workspace_id, + conversation_id=conversation_id, + user_id=user_id, + variables=variables, + parallel=parallel, + timeout=timeout ): yield event @@ -2009,7 +2013,8 @@ def create_app(db: Session, *, user_id: uuid.UUID, workspace_id: uuid.UUID, data return service.create_app(user_id=user_id, workspace_id=workspace_id, data=data) -def update_app(db: Session, *, app_id: uuid.UUID, data: app_schema.AppUpdate, workspace_id: uuid.UUID | None = None) -> App: +def update_app(db: Session, *, app_id: uuid.UUID, data: app_schema.AppUpdate, + workspace_id: uuid.UUID | None = None) -> App: """更新应用(向后兼容接口)""" service = AppService(db) return service.update_app(app_id=app_id, data=data, workspace_id=workspace_id) @@ -2021,12 +2026,15 @@ def delete_app(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None return service.delete_app(app_id=app_id, workspace_id=workspace_id) -def update_agent_config(db: Session, *, app_id: uuid.UUID, data: app_schema.AgentConfigUpdate, workspace_id: uuid.UUID | None = None) -> AgentConfig: +def update_agent_config(db: Session, *, app_id: uuid.UUID, data: app_schema.AgentConfigUpdate, + workspace_id: uuid.UUID | None = None) -> AgentConfig: """更新 Agent 配置(向后兼容接口)""" service = AppService(db) return service.update_agent_config(app_id=app_id, data=data, workspace_id=workspace_id) -def update_workflow_config(db: Session, *, app_id: uuid.UUID, data: WorkflowConfigUpdate, workspace_id: uuid.UUID | None = None) -> WorkflowConfig: + +def update_workflow_config(db: Session, *, app_id: uuid.UUID, data: WorkflowConfigUpdate, + workspace_id: uuid.UUID | None = None) -> WorkflowConfig: """更新 Agent 配置(向后兼容接口)""" service = AppService(db) return service.update_workflow_config(app_id=app_id, data=data, workspace_id=workspace_id) @@ -2040,6 +2048,7 @@ def get_agent_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID service = AppService(db) return service.get_agent_config(app_id=app_id, workspace_id=workspace_id) + def get_workflow_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> WorkflowConfig: """获取 Agent 配置(向后兼容接口) @@ -2049,13 +2058,20 @@ def get_workflow_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UU return service.get_workflow_config(app_id=app_id, workspace_id=workspace_id) -def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None,version_name:str, release_notes: Optional[str] = None) -> AppRelease: +def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None, + version_name: str, release_notes: Optional[str] = None) -> AppRelease: """发布应用(向后兼容接口)""" service = AppService(db) - return service.publish(app_id=app_id, publisher_id=publisher_id,version_name = version_name, workspace_id=workspace_id, release_notes=release_notes) + return service.publish(app_id=app_id, publisher_id=publisher_id, version_name=version_name, + workspace_id=workspace_id, release_notes=release_notes) -def get_current_release(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> Optional[AppRelease]: +def get_current_release( + db: Session, + *, + app_id: uuid.UUID, + workspace_id: uuid.UUID | None = None +) -> Optional[AppRelease]: """获取当前发布版本(向后兼容接口)""" service = AppService(db) return service.get_current_release(app_id=app_id, workspace_id=workspace_id) @@ -2074,16 +2090,16 @@ def rollback(db: Session, *, app_id: uuid.UUID, version: int, workspace_id: uuid def list_apps( - db: Session, - *, - workspace_id: uuid.UUID, - type: Optional[str] = None, - visibility: Optional[str] = None, - status: Optional[str] = None, - search: Optional[str] = None, - include_shared: bool = True, - page: int = 1, - pagesize: int = 10, + db: Session, + *, + workspace_id: uuid.UUID, + type: Optional[str] = None, + visibility: Optional[str] = None, + status: Optional[str] = None, + search: Optional[str] = None, + include_shared: bool = True, + page: int = 1, + pagesize: int = 10, ) -> Tuple[List[App], int]: """列出应用(向后兼容接口)""" service = AppService(db) @@ -2100,9 +2116,9 @@ def list_apps( def get_apps_by_ids( - db: Session, - app_ids: List[str], - workspace_id: uuid.UUID + db: Session, + app_ids: List[str], + workspace_id: uuid.UUID ) -> List[App]: """根据ID列表获取应用(向后兼容接口)""" service = AppService(db) @@ -2112,14 +2128,14 @@ def get_apps_by_ids( # ==================== 向后兼容的函数接口 ==================== async def draft_run( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None + db: Session, + *, + app_id: uuid.UUID, + message: str, + conversation_id: Optional[str] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + workspace_id: Optional[uuid.UUID] = None ) -> Dict[str, Any]: """试运行 Agent(向后兼容接口)""" service = AppService(db) @@ -2134,30 +2150,28 @@ async def draft_run( async def draft_run_stream( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None + db: Session, + *, + app_id: uuid.UUID, + message: str, + conversation_id: Optional[str] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + workspace_id: Optional[uuid.UUID] = None ): """试运行 Agent 流式返回(向后兼容接口)""" service = AppService(db) async for event in service.draft_run_stream( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id + app_id=app_id, + message=message, + conversation_id=conversation_id, + user_id=user_id, + variables=variables, + workspace_id=workspace_id ): yield event - - # ==================== 依赖注入函数 ==================== def get_app_service( diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index 50773b91..601d2921 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -711,45 +711,32 @@ class EmotionAnalyticsService: end_user_id: str, db: Session, ) -> Optional[Dict[str, Any]]: - """从缓存获取个性化情绪建议 + """从 Redis 缓存获取个性化情绪建议 Args: end_user_id: 宿主ID(用户组ID) - db: 数据库会话 + db: 数据库会话(保留参数以保持接口兼容性) Returns: Dict: 缓存的建议数据,如果不存在或已过期返回 None """ try: - from app.repositories.emotion_suggestions_cache_repository import ( - EmotionSuggestionsCacheRepository, - ) + from app.cache.memory.emotion_memory import EmotionMemoryCache - logger.info(f"尝试从缓存获取情绪建议: user={end_user_id}") + logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}") - cache_repo = EmotionSuggestionsCacheRepository(db) - cache = cache_repo.get_by_end_user_id(end_user_id) + # 从 Redis 获取缓存 + cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id) - if cache is None: - logger.info(f"用户 {end_user_id} 的建议缓存不存在") + if cached_data is None: + logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期") return None - # 检查是否过期 - if cache_repo.is_expired(cache): - logger.info(f"用户 {end_user_id} 的建议缓存已过期") - return None - - logger.info(f"成功从缓存获取建议: user={end_user_id}") - - return { - "health_summary": cache.health_summary, - "suggestions": cache.suggestions, - "generated_at": cache.generated_at.isoformat(), - "cached": True - } + logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}") + return cached_data except Exception as e: - logger.error(f"从缓存获取建议失败: {str(e)}", exc_info=True) + logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True) return None async def save_suggestions_cache( @@ -759,30 +746,33 @@ class EmotionAnalyticsService: db: Session, expires_hours: int = 24 ) -> None: - """保存建议到缓存 + """保存建议到 Redis 缓存 Args: end_user_id: 宿主ID(用户组ID) suggestions_data: 建议数据 - db: 数据库会话 - expires_hours: 过期时间(小时) + db: 数据库会话(保留参数以保持接口兼容性) + expires_hours: 过期时间(小时),默认24小时 """ try: - from app.repositories.emotion_suggestions_cache_repository import ( - EmotionSuggestionsCacheRepository, + from app.cache.memory.emotion_memory import EmotionMemoryCache + + logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时") + + # 计算过期时间(秒) + expire_seconds = expires_hours * 3600 + + # 保存到 Redis + success = await EmotionMemoryCache.set_emotion_suggestions( + user_id=end_user_id, + suggestions_data=suggestions_data, + expire=expire_seconds ) - logger.info(f"保存建议到缓存: user={end_user_id}") - - cache_repo = EmotionSuggestionsCacheRepository(db) - cache_repo.create_or_update( - end_user_id=end_user_id, - health_summary=suggestions_data["health_summary"], - suggestions=suggestions_data["suggestions"], - expires_hours=expires_hours - ) - - logger.info(f"建议缓存保存成功: user={end_user_id}") + if success: + logger.info(f"建议缓存保存成功: user={end_user_id}") + else: + logger.warning(f"建议缓存保存失败: user={end_user_id}") except Exception as e: logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True) diff --git a/api/app/services/implicit_memory_service.py b/api/app/services/implicit_memory_service.py index c98f14bc..106fa808 100644 --- a/api/app/services/implicit_memory_service.py +++ b/api/app/services/implicit_memory_service.py @@ -418,48 +418,32 @@ class ImplicitMemoryService: end_user_id: str, db: Session ) -> Optional[dict]: - """从缓存获取完整用户画像 + """从 Redis 缓存获取完整用户画像 Args: end_user_id: 终端用户ID - db: 数据库会话 + db: 数据库会话(保留参数以保持接口兼容性) Returns: Dict: 缓存的画像数据,如果不存在或已过期返回 None """ try: - from app.repositories.implicit_memory_cache_repository import ( - ImplicitMemoryCacheRepository, - ) + from app.cache.memory.implicit_memory import ImplicitMemoryCache - logger.info(f"尝试从缓存获取用户画像: user={end_user_id}") + logger.info(f"尝试从 Redis 缓存获取用户画像: user={end_user_id}") - cache_repo = ImplicitMemoryCacheRepository(db) - cache = cache_repo.get_by_end_user_id(end_user_id) + # 从 Redis 获取缓存 + cached_data = await ImplicitMemoryCache.get_user_profile(end_user_id) - if cache is None: - logger.info(f"用户 {end_user_id} 的画像缓存不存在") + if cached_data is None: + logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") return None - # 检查是否过期 - if cache_repo.is_expired(cache): - logger.info(f"用户 {end_user_id} 的画像缓存已过期") - return None - - logger.info(f"成功从缓存获取用户画像: user={end_user_id}") - - return { - "end_user_id": cache.end_user_id, - "preferences": cache.preferences, - "portrait": cache.portrait, - "interest_areas": cache.interest_areas, - "habits": cache.habits, - "generated_at": cache.generated_at.isoformat(), - "cached": True - } + logger.info(f"成功从 Redis 缓存获取用户画像: user={end_user_id}") + return cached_data except Exception as e: - logger.error(f"从缓存获取用户画像失败: {str(e)}", exc_info=True) + logger.error(f"从 Redis 缓存获取用户画像失败: {str(e)}", exc_info=True) return None async def save_profile_cache( @@ -469,32 +453,33 @@ class ImplicitMemoryService: db: Session, expires_hours: int = 168 # 默认7天 ) -> None: - """保存用户画像到缓存 + """保存用户画像到 Redis 缓存 Args: end_user_id: 终端用户ID profile_data: 画像数据 - db: 数据库会话 + db: 数据库会话(保留参数以保持接口兼容性) expires_hours: 过期时间(小时),默认168小时(7天) """ try: - from app.repositories.implicit_memory_cache_repository import ( - ImplicitMemoryCacheRepository, + from app.cache.memory.implicit_memory import ImplicitMemoryCache + + logger.info(f"保存用户画像到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时") + + # 计算过期时间(秒) + expire_seconds = expires_hours * 3600 + + # 保存到 Redis + success = await ImplicitMemoryCache.set_user_profile( + user_id=end_user_id, + profile_data=profile_data, + expire=expire_seconds ) - logger.info(f"保存用户画像到缓存: user={end_user_id}") - - cache_repo = ImplicitMemoryCacheRepository(db) - cache_repo.create_or_update( - end_user_id=end_user_id, - preferences=profile_data["preferences"], - portrait=profile_data["portrait"], - interest_areas=profile_data["interest_areas"], - habits=profile_data["habits"], - expires_hours=expires_hours - ) - - logger.info(f"用户画像缓存保存成功: user={end_user_id}") + if success: + logger.info(f"用户画像缓存保存成功: user={end_user_id}") + else: + logger.warning(f"用户画像缓存保存失败: user={end_user_id}") except Exception as e: logger.error(f"保存用户画像缓存失败: {str(e)}", exc_info=True) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index f0756764..c9230a26 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -9,30 +9,27 @@ import os import re import time import uuid - from typing import Any, AsyncGenerator, Dict, List, Optional - import redis +from langchain_core.messages import HumanMessage + from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config +from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) -from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_mcp_adapters.tools import load_mcp_tools from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session @@ -50,21 +47,17 @@ _neo4j_connector = Neo4jConnector() class MemoryAgentService: """Service for memory agent operations""" - - - def writer_messages_deal(self,messages,start_time,group_id,config_id,message): - messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') - countext = re.findall(r'"status": "(.*?)",', messages)[0] + def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context): duration = time.time() - start_time - if countext == 'success': + if str(messages) == 'success': logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, duration=duration, details={"message_length": len(message)}) - return countext + return context else: logger.warning(f"Write operation failed for group {group_id}") @@ -80,9 +73,9 @@ class MemoryAgentService: ) raise ValueError(f"写入失败: {messages}") - - + + def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" last_message = event["messages"][-1] @@ -119,15 +112,15 @@ class MemoryAgentService: return True return False - + async def get_health_status(self) -> Dict: """ Get latest health status from Redis cache - + Returns health status information written by Celery periodic task """ logger.info("Checking health status") - + client = redis.Redis( host=settings.REDIS_HOST, port=settings.REDIS_PORT, @@ -135,34 +128,51 @@ class MemoryAgentService: password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None ) payload = client.hgetall("memsci:health:read_service") or {} - + if payload: # decode bytes to str decoded = {k.decode("utf-8"): v.decode("utf-8") for k, v in payload.items()} status = decoded.get("status", "unknown") else: status = "unknown" - + + # Add database connection pool status + try: + from app.db import get_pool_status + pool_status = get_pool_status() + logger.info(f"Database pool status: {pool_status}") + + # Check if pool usage is too high + if pool_status.get("usage_percent", 0) > 80: + logger.warning(f"High database pool usage: {pool_status['usage_percent']}%") + status = "warning" + + except Exception as e: + logger.error(f"Failed to get pool status: {e}") + pool_status = {"error": str(e)} + logger.info(f"Health status: {status}") - return {"status": status} + return { + "status": status, + "database_pool": pool_status + } def get_log_content(self) -> str: """ Read and return agent service log file content - - Returns cleaned log content using the same cleaning logic as transmission mode + + Returns cleaned log content using the same cleaning logic as transmission mode Returns cleaned log content using the same cleaning logic as transmission mode """ logger.info("Reading log file") - # Use project root directory for logs - # Get the project root (redbear-mem directory) + current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = os.path.dirname(app_dir) # redbear-mem directory log_path = os.path.join(project_root, "logs", "agent_service.log") - + summer = '' with open(log_path, "r", encoding="utf-8") as infile: @@ -176,83 +186,83 @@ class MemoryAgentService: logger.info(f"Log content retrieved, size: {len(summer)} bytes") return summer - + async def stream_log_content(self) -> AsyncGenerator[str, None]: """ Stream log content in real-time using Server-Sent Events (SSE) - + This method establishes a streaming connection and transmits log entries as they are written to the log file. It uses the LogStreamer to watch the file and yields SSE-formatted messages. - + Yields: SSE-formatted strings with the following event types: - log: Contains log content and timestamp - keepalive: Periodic keepalive messages to maintain connection - error: Error information if streaming fails - done: Indicates streaming has completed - + Raises: FileNotFoundError: If log file doesn't exist at stream start Exception: For other unexpected errors during streaming """ logger.info("Starting log content streaming") - + # Get log file path - use project root directory current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = os.path.dirname(app_dir) # redbear-mem directory log_path = os.path.join(project_root, "logs", "agent_service.log") - + # Check if file exists before starting stream if not os.path.exists(log_path): logger.error(f"Log file not found: {log_path}") # Send error event in SSE format yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n" return - + streamer = None try: # Initialize LogStreamer with keepalive interval from settings (default 300 seconds) keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300) streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval) - + logger.info(f"LogStreamer initialized for {log_path}") - + # Stream log content using read_existing_and_stream to get all existing content first async for message in streamer.read_existing_and_stream(): event_type = message.get("event") data = message.get("data") - + # Format as SSE message # SSE format: "event: \ndata: \n\n" sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - + logger.debug(f"Streaming event: {event_type}") yield sse_message - + # If error or done event, stop streaming if event_type in ["error", "done"]: logger.info(f"Stream ended with event: {event_type}") break - + except FileNotFoundError as e: logger.error(f"Log file not found during streaming: {e}") yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n" - + except Exception as e: logger.error(f"Unexpected error during log streaming: {e}", exc_info=True) yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n" - + finally: # Resource cleanup logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - + async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id - + Args: group_id: Group identifier (also used as end_user_id) message: Message to write @@ -260,10 +270,10 @@ class MemoryAgentService: db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID - + Returns: Write operation result status - + Raises: ValueError: If config loading fails or write operation fails """ @@ -279,7 +289,7 @@ class MemoryAgentService: raise # Re-raise our specific error logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") - + import time start_time = time.time() @@ -294,61 +304,49 @@ class MemoryAgentService: except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) - + # Log failed operation if audit_logger: duration = time.time() - start_time audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) - + raise ValueError(error_msg) - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) - return result - else: - async with client.session("data_flow") as session: - logger.debug("Connected to MCP Server: data_flow") - tools = await load_mcp_tools(session) - workflow_errors = [] # Track errors from workflow - - # Pass memory_config to the graph workflow - async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph: - logger.debug("Write graph created successfully") + try: + if storage_type == "rag": + result = await write_rag(group_id, message, user_rag_memory_id) + return result + else: + async with make_write_graph() as graph: config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, + "memory_config": memory_config} - async for event in graph.astream( - {"messages": message, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.error(f"Write workflow failed with errors: {error_details}") - - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=error_details - ) - - raise ValueError(f"Write workflow failed: {error_details}") - - return self.writer_messages_deal(messages, start_time, group_id, config_id, message) - + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Write operation failed: {str(e)}" + logger.error(error_msg) + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + raise ValueError(error_msg) + + + + async def read_memory( self, group_id: str, @@ -362,12 +360,12 @@ class MemoryAgentService: ) -> Dict: """ Process read operation with config_id - + search_switch values: - "0": Requires verification - "1": No verification, direct split - "2": Direct answer based on context - + Args: group_id: Group identifier (also used as end_user_id) message: User message @@ -377,18 +375,17 @@ class MemoryAgentService: db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID - + Returns: Dict with 'answer' and 'intermediate_outputs' keys - + Raises: ValueError: If config loading fails """ import time start_time = time.time() - ori_message=message - end_user_id=group_id + # Resolve config_id if None using end_user's connected config if config_id is None: try: @@ -410,6 +407,7 @@ class MemoryAgentService: except ImportError: audit_logger = None + try: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( @@ -440,326 +438,128 @@ class MemoryAgentService: logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - async with client.session('data_flow') as session: - session_start = time.time() - logger.debug("Connected to MCP Server: data_flow") - - tools_start = time.time() - tools = await load_mcp_tools(session) - tools_time = time.time() - tools_start - logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - - outputs = [] - intermediate_outputs = [] - seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates - - # Pass memory_config to the graph workflow - graph_start = time.time() - async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: - graph_init_time = time.time() - graph_start - logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - - start = time.time() + try: + async with make_read_graph() as graph: config = {"configurable": {"thread_id": group_id}} - workflow_errors = [] # Track errors from workflow - - event_count = 0 - async for event in graph.astream( - {"messages": history, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, + "group_id": group_id + , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} + # 获取节点更新信息 + _intermediate_outputs = [] + summary = '' + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - event_count += 1 - event_start = time.time() - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) + for node_name, node_data in update_event.items(): + # if 'save_neo4j' == node_name: + # massages = node_data + print(f"处理节点: {node_name}") - for msg in messages: - msg_content = msg.content - msg_role = msg.__class__.__name__.lower().replace("message", "") - outputs.append({ - "role": msg_role, - "content": msg_content - }) + # 处理不同Summary节点的返回结构 + if 'Summary' in node_name: + if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: + summary = node_data['InputSummary']['summary_result'] + elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: + summary = node_data['RetrieveSummary']['summary_result'] + elif 'summary' in node_data and 'summary_result' in node_data['summary']: + summary = node_data['summary']['summary_result'] + elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: + summary = node_data['SummaryFails']['summary_result'] - # Extract intermediate outputs - if hasattr(msg, 'content'): - try: - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - content_to_parse = msg_content - if isinstance(msg_content, list): - for block in msg_content: - if isinstance(block, dict) and block.get('type') == 'text': - content_to_parse = block.get('text', '') - break - else: - continue # No text block found + spit_data = node_data.get('spit_data', {}).get('_intermediate', None) + if spit_data and spit_data != [] and spit_data != {}: + _intermediate_outputs.append(spit_data) - # Try to parse content as JSON - if isinstance(content_to_parse, str): - try: - parsed = json.loads(content_to_parse) - if isinstance(parsed, dict): - # Check for single intermediate output - if '_intermediate' in parsed: - intermediate_data = parsed['_intermediate'] - output_key = self._create_intermediate_key(intermediate_data) + # Problem_Extension 节点 + problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) + if problem_extension and problem_extension != [] and problem_extension != {}: + _intermediate_outputs.append(problem_extension) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + # Retrieve 节点 + retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) + if retrieve_node and retrieve_node != [] and retrieve_node != {}: + _intermediate_outputs.extend(retrieve_node) - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in parsed: - for intermediate_data in parsed['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) + # Verify 节点 + verify_n = node_data.get('verify', {}).get('_intermediate', None) + if verify_n and verify_n != [] and verify_n != {}: + _intermediate_outputs.append(verify_n) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except (json.JSONDecodeError, ValueError): - pass - except Exception as e: - logger.debug(f"Failed to extract intermediate output: {e}") + # Summary 节点 + summary_n = node_data.get('summary', {}).get('_intermediate', None) + if summary_n and summary_n != [] and summary_n != {}: + _intermediate_outputs.append(summary_n) - event_time = time.time() - event_start - logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") + _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] - workflow_duration = time.time() - start - session_duration = time.time() - session_start - logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") - logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") - logger.info(f"[PERF] Total events processed: {event_count}") - # Extract final answer - final_answer = "" - for messages in outputs: - if messages['role'] == 'tool': - message = messages['content'] + optimized_outputs = merge_multiple_search_results(_intermediate_outputs) + result = reorder_output_results(optimized_outputs) - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(message, list): - # Extract text from MCP content blocks - for block in message: - if isinstance(block, dict) and block.get('type') == 'text': - message = block.get('text', '') - break - else: - continue # No text block found - - try: - parsed = json.loads(message) if isinstance(message, str) else message - if isinstance(parsed, dict): - if parsed.get('status') == 'success': - summary_result = parsed.get('summary_result') - if summary_result: - final_answer = summary_result - except (json.JSONDecodeError, ValueError): - pass - - # 记录成功的操作 - total_duration = time.time() - start_time - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.warning(f"Read workflow completed with errors: {error_details}") + # Log successful operation + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + group_id=group_id, + success=True, + duration=duration + ) + return { + "answer": summary, + "intermediate_outputs": result + } + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Read operation failed: {str(e)}" + logger.error(error_msg) if audit_logger: + duration = time.time() - start_time audit_logger.log_operation( operation="READ", config_id=config_id, group_id=group_id, success=False, - duration=total_duration, - error=error_details, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer), - "errors": workflow_errors - } + duration=duration, + error=error_msg ) - - # Raise error if no answer was produced - if not final_answer: - raise ValueError(f"Read workflow failed: {error_details}") - - if audit_logger and not workflow_errors: - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=True, - duration=total_duration, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer) - } - ) - retrieved_content=[] - repo = ShortTermMemoryRepository(db) - if str(search_switch)!="2": - for intermediate in intermediate_outputs: - print(intermediate) - intermediate_type=intermediate['type'] - if intermediate_type=="search_result": - query=intermediate['query'] - raw_results=intermediate['raw_results'] - reranked_results=raw_results.get('reranked_results',[]) - try: - statements=[statement['statement'] for statement in reranked_results.get('statements', [])] - except Exception: - statements=[] - statements=list(set(statements)) - retrieved_content.append({query:statements}) - if retrieved_content==[]: - retrieved_content='' - if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] - # 使用 upsert 方法 - repo.upsert( - end_user_id=end_user_id, # 确保这个变量在作用域内 - messages=ori_message, - aimessages=final_answer, - retrieved_content=retrieved_content, - search_switch=str(search_switch) - ) - print("写入成功") + raise ValueError(error_msg) - return { - "answer": final_answer, - "intermediate_outputs": intermediate_outputs - } - - def _create_intermediate_key(self, output: Dict) -> str: - """ - Create a unique key for an intermediate output to detect duplicates. - - Args: - output: Intermediate output dictionary - - Returns: - Unique string key for this output - """ - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - # Use type + original query as key - return f"split:{output.get('original_query', '')}" - elif output_type == 'problem_extension': - # Use type + original query as key - return f"extension:{output.get('original_query', '')}" - elif output_type == 'search_result': - # Use type + query + index as key - return f"search:{output.get('query', '')}:{output.get('index', 0)}" - elif output_type == 'retrieval_summary': - # Use type + query as key - return f"summary:{output.get('query', '')}" - elif output_type == 'verification': - # Use type + query as key - return f"verification:{output.get('query', '')}" - elif output_type == 'input_summary': - # Use type + query as key - return f"input_summary:{output.get('query', '')}" - else: - # Fallback: use JSON representation - import json - return json.dumps(output, sort_keys=True) - - def _format_intermediate_output(self, output: Dict) -> Dict: - """Format intermediate output for frontend display.""" - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - return { - 'type': 'problem_split', - 'title': '问题拆分', - 'data': output.get('data', []), - 'original_query': output.get('original_query', '') - } - elif output_type == 'problem_extension': - return { - 'type': 'problem_extension', - 'title': '问题扩展', - 'data': output.get('data', {}), - 'original_query': output.get('original_query', '') - } - elif output_type == 'search_result': - return { - 'type': 'search_result', - 'title': f'检索结果 ({output.get("index", 0)}/{output.get("total", 0)})', - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results', ''), - 'index': output.get('index', 0), - 'total': output.get('total', 0) - } - elif output_type == 'retrieval_summary': - return { - 'type': 'retrieval_summary', - 'title': '检索总结', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - elif output_type == 'verification': - return { - 'type': 'verification', - 'title': '数据验证', - 'result': output.get('result', 'unknown'), - 'reason': output.get('reason', ''), - 'query': output.get('query', ''), - 'verified_count': output.get('verified_count', 0) - } - elif output_type == 'input_summary': - return { - 'type': 'input_summary', - 'title': '快速答案', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - else: - return output - async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: """ Determine the type of user message (read or write) Updated to eliminate global variables in favor of explicit parameters. - + Args: message: User message to classify config_id: Configuration ID to load LLM model from database db: Database session - + Returns: Type classification result """ logger.info("Classifying message type") - + # Load configuration to get LLM model ID config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( config_id=config_id, service_name="MemoryAgentService" ) - + status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") return status - + # ==================== 新增的三个接口方法 ==================== - + async def get_knowledge_type_stats( self, end_user_id: Optional[str] = None, @@ -772,13 +572,13 @@ class MemoryAgentService: 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤) 3. total: 所有类型的总和 - + 参数: - end_user_id: 用户组ID(可选,未提供时 memory 统计为 0) - only_active: 是否仅统计有效记录 - current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0) - db: 数据库会话 - + 返回格式: { "General": count, @@ -790,18 +590,18 @@ class MemoryAgentService: } """ result = {} - + # 1. 统计 PostgreSQL 中的知识库类型 try: if db is None: from app.db import get_db db_gen = get_db() db = next(db_gen) - + # 初始化所有标准类型为 0 for kb_type in KnowledgeType: result[kb_type.value] = 0 - + # 如果提供了 workspace_id,则按 workspace_id 过滤 if current_workspace_id: # 构建查询条件 @@ -809,47 +609,48 @@ class MemoryAgentService: Knowledge.type, func.count(Knowledge.id).label('count') ).filter(Knowledge.workspace_id == current_workspace_id) - + # 检查 Knowledge 模型是否有 status 字段 if only_active and hasattr(Knowledge, 'status'): query = query.filter(Knowledge.status == 1) - + # 按类型分组 type_counts = query.group_by(Knowledge.type).all() - + # 只填充标准类型的统计值,忽略其他类型 valid_types = {kb_type.value for kb_type in KnowledgeType} for type_name, count in type_counts: if type_name in valid_types: result[type_name] = count - + logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}") else: # 没有提供 workspace_id,所有知识库类型返回 0 logger.info("未提供 workspace_id,知识库类型统计全部为 0") - + except Exception as e: logger.error(f"知识库类型统计失败: {e}") raise Exception(f"知识库类型统计失败: {e}") - + # 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数) try: if current_workspace_id: # 获取当前空间下的所有宿主 from app.repositories import app_repository, end_user_repository from app.schemas.app_schema import App as AppSchema - + from app.schemas.end_user_schema import EndUser as EndUserSchema + # 查询应用并转换为 Pydantic 模型 apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) apps = [AppSchema.model_validate(h) for h in apps_orm] app_ids = [app.id for app in apps] - + # 获取所有宿主 end_users = [] for app_id in app_ids: end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id) end_users.extend(h for h in end_user_orm_list) - + # 统计所有宿主的 Chunk 总数 total_chunks = 0 for end_user in end_users: @@ -864,27 +665,27 @@ class MemoryAgentService: chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 total_chunks += chunk_count logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}") - + result["memory"] = total_chunks logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}") else: # 没有 workspace_id 时,返回 0 result["memory"] = 0 logger.info("未提供 workspace_id,memory 统计为 0") - + except Exception as e: logger.error(f"Neo4j memory统计失败: {e}", exc_info=True) # 如果 Neo4j 查询失败,memory 设为 0 result["memory"] = 0 - + # 3. 计算知识库类型总和(不包括 memory) result["total"] = ( - result.get("General", 0) + - result.get("Web", 0) + - result.get("Third-party", 0) + + result.get("General", 0) + + result.get("Web", 0) + + result.get("Third-party", 0) + result.get("Folder", 0) ) - + return result @@ -895,11 +696,11 @@ class MemoryAgentService: ) -> List[Dict[str, Any]]: """ 获取指定用户的热门记忆标签 - + 参数: - end_user_id: 用户ID(可选),对应Neo4j中的group_id字段 - limit: 返回标签数量限制 - + 返回格式: [ {"name": "标签名", "frequency": 频次}, @@ -928,13 +729,13 @@ class MemoryAgentService: 1. 用户名字(直接使用 end_user_name) 2. 用户标签(从摘要中用LLM总结3个标签) 3. 热门记忆标签(从hot_memory_tags获取前4个) - + 参数: - end_user_id: 用户ID(可选) - current_user_id: 当前登录用户的ID(保留参数) - llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成) - db: 数据库会话(可选) - + 返回格式: { "name": "用户名", @@ -947,13 +748,13 @@ class MemoryAgentService: } """ result = {} - + # 1. 根据 end_user_id 获取 end_user_name try: if end_user_id and db: from app.repositories import end_user_repository from app.schemas.end_user_schema import EndUser as EndUserSchema - + end_user_orm = end_user_repository.get_end_user_by_id(db, end_user_id) if end_user_orm: end_user = EndUserSchema.model_validate(end_user_orm) @@ -965,14 +766,14 @@ class MemoryAgentService: except Exception as e: logger.error(f"Failed to get end_user_name: {e}") end_user_name = "默认用户" - + result["name"] = end_user_name logger.debug(f"The end_user is: {end_user_name}") - + # 2. 使用LLM从语句和实体中提取标签 try: connector = Neo4jConnector() - + # 查询该用户的语句 query = ( "MATCH (s:Statement) " @@ -982,7 +783,7 @@ class MemoryAgentService: ) rows = await connector.execute_query(query, group_id=end_user_id) statements = [r.get("statement", "") for r in rows if r.get("statement")] - + # 查询该用户的热门实体 entity_query = ( "MATCH (e:ExtractedEntity) " @@ -992,9 +793,9 @@ class MemoryAgentService: ) entity_rows = await connector.execute_query(entity_query, group_id=end_user_id) entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows] - + await connector.close() - + if not statements or not llm_id: result["tags"] = [] if not llm_id and statements: @@ -1003,16 +804,16 @@ class MemoryAgentService: # 构建摘要文本 summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}" logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities") - + # 使用LLM提取标签 with get_db_context() as db: factory = MemoryClientFactory(db) llm_client = factory.get_llm_client(llm_id) - + # 定义标签提取的结构 class UserTags(BaseModel): tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") - + messages = [ { "role": "system", @@ -1023,20 +824,20 @@ class MemoryAgentService: "content": f"请从以下用户信息中提取3个标签:\n\n{summary_text}" } ] - + user_tags = await llm_client.response_structured( messages=messages, response_model=UserTags ) - + result["tags"] = user_tags.tags logger.debug(f"Extracted tags: {user_tags.tags}") - + except Exception as e: # 如果提取失败,使用默认值 logger.error(f"Failed to extract user tags: {e}") result["tags"] = [] - + try: # 3. 获取热门记忆标签(前4个) connector = Neo4jConnector() @@ -1049,18 +850,18 @@ class MemoryAgentService: "ORDER BY frequency DESC LIMIT 4" ) hot_tag_rows = await connector.execute_query( - hot_tag_query, - group_id=end_user_id, + hot_tag_query, + group_id=end_user_id, names_to_exclude=names_to_exclude ) await connector.close() - + result["hot_tags"] = [{"name": r["name"], "frequency": r["frequency"]} for r in hot_tag_rows] logger.debug(f"Hot tags found: {len(result['hot_tags'])} tags") except Exception as e: logger.error(f"Failed to get hot tags: {e}") result["hot_tags"] = [] - + return result async def stream_log_content(self) -> AsyncGenerator[str, None]: @@ -1135,79 +936,40 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic -# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]: -# """ -# Parse and return API documentation - -# Args: -# file_path: Optional path to API docs file. If None, uses default path. - -# Returns: -# Dict containing parsed API documentation or error information -# """ -# try: -# target = file_path or get_default_docs_path() - -# if not os.path.isfile(target): -# return { -# "success": False, -# "msg": "API文档文件不存在", -# "error_code": "DOC_NOT_FOUND", -# "data": {"path": target} -# } - -# data = parse_api_docs(target) -# return { -# "success": True, -# "msg": "解析成功", -# "data": data -# } -# except Exception as e: -# logger.error(f"Failed to parse API docs: {e}") -# return { -# "success": False, -# "msg": "解析失败", -# "error_code": "DOC_PARSE_ERROR", -# "data": {"error": str(e)} -# } - - def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]: """ 获取终端用户关联的记忆配置 - + 通过以下流程获取配置: 1. 根据 end_user_id 获取用户的 app_id 2. 获取该应用的最新发布版本 3. 从发布版本的 config 字段中提取 memory_config_id - 4. 根据 memory_config_id 查询配置名称 - + Args: end_user_id: 终端用户ID db: 数据库会话 - + Returns: - 包含 memory_config_id、config_name 和相关信息的字典 - + 包含 memory_config_id 和相关信息的字典 + Raises: ValueError: 当终端用户不存在或应用未发布时 """ from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig from app.models.end_user_model import EndUser from sqlalchemy import select - + logger.info(f"Getting connected config for end_user: {end_user_id}") - + # 1. 获取 end_user 及其 app_id end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() if not end_user: logger.warning(f"End user not found: {end_user_id}") raise ValueError(f"终端用户不存在: {end_user_id}") - + app_id = end_user.app_id logger.debug(f"Found end_user app_id: {app_id}") - + # 2. 获取该应用的最新发布版本 stmt = ( select(AppRelease) @@ -1215,170 +977,135 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An .order_by(AppRelease.version.desc()) ) latest_release = db.scalars(stmt).first() - + if not latest_release: logger.warning(f"No active release found for app: {app_id}") raise ValueError(f"应用未发布: {app_id}") - + logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}") - + # 3. 从 config 中提取 memory_config_id config = latest_release.config or {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - - # 4. 根据 memory_config_id 查询配置名称 - config_name = None - if memory_config_id: - try: - # memory_config_id 可能是整数或字符串,需要转换 - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() - if data_config: - config_name = data_config.config_name - logger.debug(f"Found config_name: {config_name} for config_id: {config_id}") - else: - logger.warning(f"DataConfig not found for config_id: {config_id}") - except (ValueError, TypeError) as e: - logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}") - + result = { "end_user_id": str(end_user_id), "app_id": str(app_id), "release_id": str(latest_release.id), "release_version": latest_release.version, - "memory_config_id": memory_config_id, - "memory_config_name": config_name + "memory_config_id": memory_config_id } - - logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}") + + logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") return result def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]: """ - 批量获取多个终端用户关联的记忆配置 - - 通过优化的查询减少数据库往返次数: - 1. 一次性查询所有 end_user 及其 app_id - 2. 批量查询所有相关的 app_release - 3. 批量查询所有相关的 data_config - + 批量获取多个终端用户关联的记忆配置(优化版本,减少数据库查询次数) + + 通过以下流程获取配置: + 1. 批量查询所有 end_user_id 对应的 app_id + 2. 批量获取这些应用的最新发布版本 + 3. 从发布版本的 config 字段中提取 memory_config_id + Args: end_user_ids: 终端用户ID列表 db: 数据库会话 - + Returns: - 字典,key 为 end_user_id,value 为配置信息字典 - 对于查询失败的用户,value 包含 error 字段 + 字典,key 为 end_user_id,value 为包含 memory_config_id 和 memory_config_name 的字典 + 格式: { + "user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"}, + "user_id_2": {"memory_config_id": None, "memory_config_name": None}, + ... + } """ from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig from app.models.end_user_model import EndUser + from app.models.memory_config_model import MemoryConfig from sqlalchemy import select - - logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users") - + + logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") + result = {} - + + # 如果列表为空,直接返回空字典 + if not end_user_ids: + return result + # 1. 批量查询所有 end_user 及其 app_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - # 构建 end_user_id -> end_user 的映射 - end_user_map = {str(user.id): user for user in end_users} + # 创建 end_user_id -> app_id 的映射 + user_to_app = {str(eu.id): eu.app_id for eu in end_users} - # 记录不存在的用户 - for user_id in end_user_ids: - if user_id not in end_user_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"终端用户不存在: {user_id}" - } - - if not end_users: - logger.warning("No valid end users found") + # 记录未找到的用户 + found_user_ids = set(user_to_app.keys()) + missing_user_ids = set(end_user_ids) - found_user_ids + if missing_user_ids: + logger.warning(f"End users not found: {missing_user_ids}") + for user_id in missing_user_ids: + result[user_id] = {"memory_config_id": None, "memory_config_name": None} + + # 2. 批量获取所有相关应用的最新发布版本 + app_ids = list(user_to_app.values()) + if not app_ids: return result - - # 2. 批量查询所有相关应用的最新发布版本 - app_ids = [user.app_id for user in end_users] - - # 使用子查询找到每个 app 的最新版本 - from sqlalchemy import and_ - - # 查询所有相关的活跃发布版本 - releases = db.query(AppRelease).filter( - and_( - AppRelease.app_id.in_(app_ids), - AppRelease.is_active.is_(True) - ) - ).order_by(AppRelease.app_id, AppRelease.version.desc()).all() - - # 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本) - app_release_map = {} + + # 查询所有活跃的发布版本 + stmt = ( + select(AppRelease) + .where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True)) + .order_by(AppRelease.app_id, AppRelease.version.desc()) + ) + releases = db.scalars(stmt).all() + + # 创建 app_id -> latest_release 的映射(每个 app 只保留最新版本) + app_to_release = {} for release in releases: - app_id_str = str(release.app_id) - if app_id_str not in app_release_map: - app_release_map[app_id_str] = release - - # 3. 收集所有 memory_config_id + if release.app_id not in app_to_release: + app_to_release[release.app_id] = release + + # 3. 收集所有 memory_config_id 并批量查询配置名称 memory_config_ids = [] - for release in app_release_map.values(): - config = release.config or {} - memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - memory_config_ids.append(config_id) - except (ValueError, TypeError): - pass - - # 4. 批量查询所有 data_config - config_name_map = {} + for end_user_id, app_id in user_to_app.items(): + release = app_to_release.get(app_id) + if release: + config = release.config or {} + memory_obj = config.get('memory', {}) + memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + if memory_config_id: + memory_config_ids.append(memory_config_id) + + # 批量查询 memory_config_name + config_id_to_name = {} if memory_config_ids: - data_configs = db.query(DataConfig).filter( - DataConfig.config_id.in_(memory_config_ids) - ).all() - config_name_map = {config.config_id: config.config_name for config in data_configs} - - # 5. 组装结果 - for user in end_users: - user_id = str(user.id) - app_id = str(user.app_id) + memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() + config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs} + + # 4. 构建最终结果 + for end_user_id, app_id in user_to_app.items(): + release = app_to_release.get(app_id) - # 检查是否有发布版本 - if app_id not in app_release_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"应用未发布: {app_id}" - } + if not release: + logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})") + result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} continue - - release = app_release_map[app_id] - - # 提取 memory_config_id + + # 从 config 中提取 memory_config_id config = release.config or {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - # 获取 config_name - config_name = None - if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - config_name = config_name_map.get(config_id) - except (ValueError, TypeError): - pass - - result[user_id] = { - "end_user_id": user_id, + # 获取配置名称 + memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None + + result[end_user_id] = { "memory_config_id": memory_config_id, - "memory_config_name": config_name + "memory_config_name": memory_config_name } - - logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}") + + logger.info(f"Successfully retrieved {len(result)} connected configs") return result \ No newline at end of file diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index 0f8fb569..46e42b46 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -120,10 +120,12 @@ class WorkspaceAppService: def _get_data_config(self, memory_content: str) -> Dict[str, Any]: """Retrieve data_comfig information based on memory_comtent""" try: - data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) - data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() - if data_config_result is None: - return None + data_config_result = DataConfigRepository.query_reflection_config_by_id(self.db, int(memory_content)) + + # data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) + # data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() + # if data_config_result is None: + # return None if data_config_result: return { @@ -206,6 +208,47 @@ class MemoryReflectionService: def __init__(self,db: Session = Depends(get_db)): self.db=db + async def start_text_reflection(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: + try: + config_id = config_data.get("config_id") + api_logger.info(f"从配置数据启动反思,config_id: {config_id}, end_user_id: {end_user_id}") + + if not config_data.get("enable_self_reflexion", False): + return { + "status": "跳过", + "message": "反思引擎未启用", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } + + config_data_id = config_data['config_id'] + reflection_config = WorkspaceAppService(self.db)._get_data_config(config_data_id) + if reflection_config is not None and reflection_config['enable_self_reflexion']: + reflection_config = self._create_reflection_config_from_data(reflection_config) + # 3. 执行反思引擎 + reflection_results = await self._execute_reflection_engine( + reflection_config, end_user_id + ) + return { + "status": "完成", + "message": "反思引擎执行完成", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data, + "reflection_results": reflection_results + } + + except Exception as e: + config_id = config_data.get("config_id", "unknown") + api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}") + return { + "status": "错误", + "message": f"启动反思失败: {str(e)}", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: """ @@ -237,16 +280,41 @@ class MemoryReflectionService: reflection_config=WorkspaceAppService(self.db)._get_data_config(config_data_id) if reflection_config is not None and reflection_config['enable_self_reflexion']: reflection_config= self._create_reflection_config_from_data(reflection_config) - iteration_period=reflection_config.iteration_period + iteration_period = int(reflection_config.iteration_period) workspace_service = WorkspaceAppService(self.db) current_reflection_time = workspace_service.get_end_user_reflection_time(end_user_id) - reflection_time = datetime.fromisoformat(str(current_reflection_time)) - - current_time = datetime.now() - time_diff = current_time - reflection_time - hours_diff = int(time_diff.total_seconds() / 3600) - if iteration_period==hours_diff or current_reflection_time is None: + # 检查是否需要执行反思 + should_execute = False + hours_diff = 0 + + if current_reflection_time is None: + # 首次执行反思 + should_execute = True + api_logger.info(f"首次执行反思,end_user_id: {end_user_id}") + else: + # 计算时间差 + try: + if isinstance(current_reflection_time, str): + reflection_time = datetime.fromisoformat(current_reflection_time) + else: + reflection_time = current_reflection_time + + current_time = datetime.now() + time_diff = current_time - reflection_time + hours_diff = int(time_diff.total_seconds() / 3600) + + # 检查是否达到反思周期 + if hours_diff >= iteration_period: + should_execute = True + api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时,达到周期 {iteration_period} 小时") + else: + api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时,未达到周期 {iteration_period} 小时") + except (ValueError, TypeError) as e: + api_logger.warning(f"解析反思时间失败: {e},将执行反思") + should_execute = True + + if should_execute: api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时") # 3. 执行反思引擎 reflection_results = await self._execute_reflection_engine( @@ -269,13 +337,15 @@ class MemoryReflectionService: } else: return { - "status": "等待中..", - "message": "反思引擎未开始执行执", + "status": "等待中", + "message": f"反思引擎未开始执行,距离下次执行还需 {iteration_period - hours_diff} 小时", "config_id": config_id, "end_user_id": end_user_id, "config_data": config_data, - "reflection_results": '' + "hours_since_last_reflection": hours_diff, + "next_reflection_in_hours": iteration_period - hours_diff } + except Exception as e: config_id = config_data.get("config_id", "unknown") diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 974d5418..b7d5df02 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -4,7 +4,7 @@ import datetime import logging import uuid -from typing import Any, Annotated, AsyncGenerator +from typing import Any, Annotated, AsyncGenerator, Optional from deprecated import deprecated from fastapi import Depends @@ -14,15 +14,14 @@ from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.workflow.validator import validate_workflow_config from app.db import get_db -from app.models.conversation_model import Message from app.models.workflow_model import WorkflowConfig, WorkflowExecution -from app.repositories.conversation_repository import MessageRepository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) from app.schemas import DraftRunRequest +from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ class WorkflowService: self.config_repo = WorkflowConfigRepository(db) self.execution_repo = WorkflowExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db) - self.message_repo = MessageRepository(db) + self.conversation_service = ConversationService(db) # ==================== 配置管理 ==================== @@ -266,6 +265,7 @@ class WorkflowService: workflow_config_id: uuid.UUID, app_id: uuid.UUID, trigger_type: str, + release_id: uuid.UUID | None = None, triggered_by: uuid.UUID | None = None, conversation_id: uuid.UUID | None = None, input_data: dict[str, Any] | None = None @@ -273,6 +273,7 @@ class WorkflowService: """创建工作流执行记录 Args: + release_id: 应用发布 ID workflow_config_id: 工作流配置 ID app_id: 应用 ID trigger_type: 触发类型 @@ -289,6 +290,7 @@ class WorkflowService: execution = WorkflowExecution( workflow_config_id=workflow_config_id, app_id=app_id, + release_id=release_id, conversation_id=conversation_id, execution_id=execution_id, trigger_type=trigger_type, @@ -337,6 +339,7 @@ class WorkflowService: self, execution_id: str, status: str, + token_usage: int | None = None, output_data: dict[str, Any] | None = None, error_message: str | None = None, error_node_id: str | None = None @@ -346,6 +349,7 @@ class WorkflowService: Args: execution_id: 执行 ID status: 状态 + token_usage: token消耗 output_data: 输出数据 error_message: 错误信息 error_node_id: 出错节点 ID @@ -364,6 +368,8 @@ class WorkflowService: ) execution.status = status + if token_usage is not None: + execution.token_usage = token_usage if output_data is not None: execution.output_data = convert_uuids_to_str(output_data) if error_message is not None: @@ -414,12 +420,14 @@ class WorkflowService: payload: DraftRunRequest, config: WorkflowConfig, workspace_id: uuid.UUID, + release_id: uuid.UUID | None = None, ): """运行工作流 Args: - workspace_id: - config: + release_id: 发布 ID + workspace_id:工作空间 ID + config: 配置 payload: app_id: 应用 ID @@ -463,7 +471,8 @@ class WorkflowService: trigger_type="manual", triggered_by=None, conversation_id=conversation_id_uuid, - input_data=input_data + input_data=input_data, + release_id=release_id, ) # 3. 构建工作流配置字典 @@ -507,20 +516,20 @@ class WorkflowService: # 更新执行结果 if result.get("status") == "completed": + token_usage = result.get("token_usage", {}) or {} self.update_execution_status( execution.execution_id, "completed", - output_data=result + output_data=result, + token_usage=token_usage.get("total_tokens", None) ) final_messages = result.get("messages", [])[init_message_length:] for message in final_messages: - message_obj = Message( + self.conversation_service.add_message( conversation_id=conversation_id_uuid, role=message["role"], - content=message["content"], + content=message["content"] ) - self.message_repo.add_message(message_obj) - self.db.commit() logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") else: @@ -562,10 +571,12 @@ class WorkflowService: payload: DraftRunRequest, config: WorkflowConfig, workspace_id: uuid.UUID, + release_id: Optional[uuid.UUID] = None, ): """运行工作流(流式) Args: + release_id: 发布id workspace_id: app_id: 应用 ID payload: 请求对象(包含 message, variables, conversation_id 等) @@ -611,7 +622,8 @@ class WorkflowService: trigger_type="manual", triggered_by=None, conversation_id=conversation_id_uuid, - input_data=input_data + input_data=input_data, + release_id=release_id, ) # 3. 构建工作流配置字典 @@ -653,21 +665,21 @@ class WorkflowService: if event.get("event") == "workflow_end": status = event.get("data", {}).get("status") + token_usage = event.get("data", {}).get("token_usage", {}) or {} if status == "completed": self.update_execution_status( execution.execution_id, "completed", - output_data=event.get("data") + output_data=event.get("data"), + token_usage=token_usage.get("total_tokens", None) ) final_messages = event.get("data", {}).get("messages", [])[init_message_length:] for message in final_messages: - message_obj = Message( + self.conversation_service.add_message( conversation_id=conversation_id_uuid, role=message["role"], - content=message["content"], + content=message["content"] ) - self.message_repo.add_message(message_obj) - self.db.commit() logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") elif status == "failed": @@ -784,10 +796,12 @@ class WorkflowService: # 更新执行结果 if result.get("status") == "completed": + token_usage = result.get("data").get("token_usage", {}) or {} self.update_execution_status( execution.execution_id, "completed", - output_data=result.get("node_outputs", {}) + output_data=result.get("node_outputs", {}), + token_usage=token_usage.get("total_tokens", None) ) else: self.update_execution_status( @@ -882,13 +896,14 @@ class WorkflowService: ): # 直接转发事件(executor 已经返回正确格式) if event.get("event") == "workflow_end": - + token_usage = event.get("data").get("token_usage", {}) or {} status = event.get("data", {}).get("status") if status == "completed": self.update_execution_status( execution_id, "completed", - output_data=event.get("data") + output_data=event.get("data"), + token_usage=token_usage.get("total_tokens", None) ) elif status == "failed": self.update_execution_status( diff --git a/api/app/tasks.py b/api/app/tasks.py index 28a882b7..fba9f290 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,27 +1,27 @@ import asyncio -import trio import json import os +import re import time import uuid from datetime import datetime, timezone from math import ceil from typing import Any, Dict, List, Optional -import re import redis import requests +import trio # Import a unified Celery instance from app.celery_app import celery_app from app.core.config import settings +from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache from app.core.rag.llm.chat_model import Base from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.llm.sequence2txt_model import QWenSeq2txt from app.core.rag.models.chunk import DocumentChunk -from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.prompts.generator import question_proposal from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ElasticSearchVectorFactory, @@ -486,6 +486,10 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage Raises: Exception on failure """ + from app.core.logging_config import get_logger + logger = get_logger(__name__) + + logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}") start_time = time.time() # Resolve config_id if None @@ -506,8 +510,14 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage async def _run() -> str: db = next(get_db()) try: + logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory") service = MemoryAgentService() - return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) + result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) + logger.info(f"[CELERY WRITE] Write completed successfully: {result}") + return result + except Exception as e: + logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True) + raise finally: db.close() @@ -532,6 +542,8 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time + logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + return { "status": "SUCCESS", "result": result, @@ -548,6 +560,9 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage detailed_error = "; ".join(error_messages) else: detailed_error = str(e) + + logger.error(f"[CELERY WRITE] Task failed - elapsed_time={elapsed_time:.2f}s, error={detailed_error}", exc_info=True) + return { "status": "FAILURE", "error": detailed_error, diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/templates/workflows/simple_qa/template.yml index 14de4a73..2cf0f9b1 100644 --- a/api/app/templates/workflows/simple_qa/template.yml +++ b/api/app/templates/workflows/simple_qa/template.yml @@ -53,7 +53,7 @@ nodes: type: end name: 结束 config: - output: "{{ llm_qa.output }}" + output: "{{llm_qa.output}}" position: x: 900 y: 100 diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index 4a35a4cc..514e4565 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -120,12 +120,9 @@ def multi_agent_config_4_app_release(release: AppRelease) -> MultiAgentConfig: def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig: config_dict = release.config - with get_db_read() as db: - source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id) - source_config_id = source_config.id config = WorkflowConfig( - id=source_config_id, + id=config_dict.get("id"), app_id=release.app_id, nodes=config_dict.get("nodes", []), edges=config_dict.get("edges", []), diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 8470a5d1..8bc19f3a 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -1,32 +1,5 @@ -version: '3.9' - services: - # MCP Server - standalone service - mcp-server: - image: redbear-mem-open:latest - container_name: mcp-server - ports: - - "8081:8081" # MCP server port - env_file: - - .env - environment: - - SERVER_IP=0.0.0.0 # Bind to all interfaces - volumes: - - ./files:/files - - /etc/localtime:/etc/localtime:ro - command: python -m app.core.memory.agent.mcp_server.server - healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8081/sse')"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 30s - restart: unless-stopped - networks: - - default - - celery - - # FastAPI application - connects to MCP server + # FastAPI application api: image: redbear-mem-open:latest container_name: api @@ -35,37 +8,31 @@ services: env_file: - .env environment: - - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name - - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces + - SERVER_IP=0.0.0.0 + # 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位 + # - MCP_SERVER_URL= volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-level debug - depends_on: - mcp-server: - condition: service_healthy restart: unless-stopped networks: - default - celery - # Celery worker - connects to MCP server + # Celery worker worker: image: redbear-mem-open:latest container_name: worker env_file: - .env - environment: - - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro command: celery -A app.celery_worker.celery_app worker --loglevel=info - depends_on: - mcp-server: - condition: service_healthy restart: unless-stopped networks: - celery + networks: - celery: \ No newline at end of file + celery: diff --git a/api/migrations/versions/8cd790908f92_202601191615.py b/api/migrations/versions/8cd790908f92_202601191615.py new file mode 100644 index 00000000..8e4624ee --- /dev/null +++ b/api/migrations/versions/8cd790908f92_202601191615.py @@ -0,0 +1,34 @@ +"""202601191615 + +Revision ID: 8cd790908f92 +Revises: 1fd7d0e703b3 +Create Date: 2026-01-19 16:15:35.058649 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '8cd790908f92' +down_revision: Union[str, None] = '1fd7d0e703b3' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('workflow_executions', sa.Column('release_id', sa.UUID(), nullable=True)) + op.create_index(op.f('ix_workflow_executions_release_id'), 'workflow_executions', ['release_id'], unique=False) + op.create_foreign_key(None, 'workflow_executions', 'app_releases', ['release_id'], ['id'], ondelete='CASCADE') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'workflow_executions', type_='foreignkey') + op.drop_index(op.f('ix_workflow_executions_release_id'), table_name='workflow_executions') + op.drop_column('workflow_executions', 'release_id') + # ### end Alembic commands ### diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 0ac14451..bbd9f6b0 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -18,174 +18,180 @@ import type { TestParams } from '@/views/MemoryConversation' import type { EndUser } from '@/views/UserMemoryDetail/types' import { handleSSE, type SSEMessage } from '@/utils/stream' -// 记忆对话 +// Memory conversation export const readService = (query: TestParams) => { return request.post('/memory/read_service', query) } -/****************** 记忆看板 相关接口 *******************************/ -// 记忆看板-记忆总量 +/****************** Memory Dashboard APIs *******************************/ +// Memory Dashboard - Total memory count export const getTotalMemoryCount = () => { return request.get(`/dashboard/total_memory_count`) } -// 记忆看板-知识库类型分布 +// Memory Dashboard - Knowledge base type distribution export const getKbTypes = () => { return request.get(`/memory/stats/types`) } -// 记忆看板-热门记忆标签 +// Memory Dashboard - Hot memory tags export const getHotMemoryTags = () => { return request.get(`/memory-storage/analytics/hot_memory_tags`) } -// 记忆看板-最近活动统计 +// Memory Dashboard - Recent activity statistics export const getRecentActivityStats = () => { return request.get(`/memory-storage/analytics/recent_activity_stats`) } -// 记忆看板-记忆增长趋势 +// Memory Dashboard - Memory growth trend export const getMemoryIncrement = (limit: number) => { return request.get(`/dashboard/memory_increment`, { limit }) } -// 记忆看板-API调用趋势 +// Memory Dashboard - API call trend export const getApiTrend = () => { return request.get(`/dashboard/api_increment`) } -// 记忆看板-总数据 +// Memory Dashboard - Total data export const getDashboardData = () => { return request.get(`/dashboard/dashboard_data`) } -/*************** end 记忆看板 相关接口 ******************************/ +/*************** end Memory Dashboard APIs ******************************/ -/****************** 用户记忆 相关接口 *******************************/ +/****************** User Memory APIs *******************************/ export const userMemoryListUrl = '/dashboard/end_users' export const getUserMemoryList = () => { return request.get(userMemoryListUrl) } -// 用户记忆-用户记忆总量 +// User Memory - Total end users export const getTotalEndUsers = () => { return request.get(`/dashboard/total_end_users`) } -// 用户记忆-用户详情 +// User Memory - User profile export const getUserProfile = (end_user_id: string) => { return request.get(`/memory/analytics/user_profile`, { end_user_id }) } -// 用户记忆-记忆洞察 +// User Memory - Memory insight export const getMemoryInsightReport = (end_user_id: string) => { return request.get(`/memory-storage/analytics/memory_insight/report`, { end_user_id }) } -// 用户记忆-用户摘要 +// User Memory - User summary export const getUserSummary = (end_user_id: string) => { return request.get(`/memory-storage/analytics/user_summary`, { end_user_id }) } -// 记忆分类 +// Memory classification export const getNodeStatistics = (end_user_id: string) => { return request.get(`/memory-storage/analytics/node_statistics`, { end_user_id }) } -// 基本信息 +// Basic information export const getEndUserProfile = (end_user_id: string) => { return request.get(`/memory-storage/read_end_user/profile`, { end_user_id }) } export const updatedEndUserProfile = (values: EndUser) => { return request.post(`/memory-storage/updated_end_user/profile`, values) } -// 用户记忆-关系网络 +// User Memory - Relationship network export const getMemorySearchEdges = (end_user_id: string) => { return request.get(`/memory-storage/analytics/graph_data`, { end_user_id }) } -// 用户记忆-用户兴趣分布 +// User Memory - User interest distribution export const getHotMemoryTagsByUser = (end_user_id: string) => { return request.get(`/memory/analytics/hot_memory_tags/by_user`, { end_user_id }) } -// 用户记忆-记忆总量 +// User Memory - Total memory count export const getTotalMemoryCountByUser = (end_user_id: string) => { return request.get(`/memory-storage/search`, { end_user_id }) } -// RAG 用户记忆-记忆总量 +// RAG User Memory - Total memory count export const getTotalRagMemoryCountByUser = (end_user_id: string) => { return request.get(`/dashboard/current_user_rag_total_num`, { end_user_id }) } -// RAG 用户记忆-用户摘要 +// RAG User Memory - User summary export const getChunkSummaryTag = (end_user_id: string) => { return request.get(`/dashboard/chunk_summary_tag`, { end_user_id }) } -// RAG 用户记忆-记忆洞察 +// RAG User Memory - Memory insight export const getChunkInsight = (end_user_id: string) => { return request.get(`/dashboard/chunk_insight`, { end_user_id }) } -// RAG 用户记忆-存储内容 +// RAG User Memory - Storage content export const getRagContent = (end_user_id: string) => { return request.get(`/dashboard/rag_content`, { end_user_id, limit: 20 }) } -// 情感分布分析 +// Emotion distribution analysis export const getWordCloud = (group_id: string) => { return request.post(`/memory/emotion-memory/wordcloud`, { group_id, limit: 20 }) } -// 高频情绪关键词 +// High-frequency emotion keywords export const getEmotionTags = (group_id: string) => { return request.post(`/memory/emotion-memory/tags`, { group_id, limit: 20 }) } -// 情绪健康指数 +// Emotion health index export const getEmotionHealth = (group_id: string) => { return request.post(`/memory/emotion-memory/health`, { group_id, limit: 20 }) } -// 个性化建议 +// Personalized suggestions export const getEmotionSuggestions = (group_id: string) => { return request.post(`/memory/emotion-memory/suggestions`, { group_id, limit: 20 }) } +export const generateSuggestions = (end_user_id: string) => { + return request.post(`/memory/emotion-memory/generate_suggestions`, { end_user_id }) +} export const analyticsRefresh = (end_user_id: string) => { return request.post('/memory-storage/analytics/generate_cache', { end_user_id }) } -// 遗忘 +// Forgetting stats export const getForgetStats = (group_id: string) => { return request.get(`/memory/forget-memory/stats`, { group_id }) } -// 隐性记忆-偏好 +// Implicit Memory - Preferences export const getImplicitPreferences = (end_user_id: string) => { return request.get(`/memory/implicit-memory/preferences/${end_user_id}`) } -// 隐性记忆-核心特质 +// Implicit Memory - Core traits export const getImplicitPortrait = (end_user_id: string) => { return request.get(`/memory/implicit-memory/portrait/${end_user_id}`) } -// 隐性记忆-兴趣领域分布 +// Implicit Memory - Interest areas distribution export const getImplicitInterestAreas = (end_user_id: string) => { return request.get(`/memory/implicit-memory/interest-areas/${end_user_id}`) } -// 隐性记忆-用户习惯分析 +// Implicit Memory - User habits analysis export const getImplicitHabits = (end_user_id: string) => { return request.get(`/memory/implicit-memory/habits/${end_user_id}`) } -// 短期记忆 +export const generateProfile = (end_user_id: string) => { + return request.post(`/memory/implicit-memory/generate_profile`, { end_user_id }) +} +// Short-term memory export const getShortTerm = (end_user_id: string) => { return request.get(`/memory/short/short_term`, { end_user_id }) } -// 感知记忆-视觉记忆 +// Perceptual Memory - Visual memory export const getPerceptualLastVisual = (end_user: string) => { return request.get(`/memory/perceptual/${end_user}/last_visual`) } -// 感知记忆-音频记忆 +// Perceptual Memory - Audio memory export const getPerceptualLastListen = (end_user: string) => { return request.get(`/memory/perceptual/${end_user}/last_listen`) } -// 感知记忆-文本记忆 +// Perceptual Memory - Text memory export const getPerceptualLastText = (end_user: string) => { return request.get(`/memory/perceptual/${end_user}/last_text`) } -// 感知记忆-感知记忆时间线 +// Perceptual Memory - Perceptual memory timeline export const getPerceptualTimeline = (end_user: string) => { return request.get(`/memory/perceptual/${end_user}/timeline`) } -// 情景记忆-总览 +// Episodic Memory - Overview export const getEpisodicOverview = (data: { end_user_id: string; time_range: string; episodic_type: string; } ) => { return request.post(`/memory/episodic-memory/overview`, data) } export const getEpisodicDetail = (data: { end_user_id: string; summary_id: string; } ) => { return request.post(`/memory/episodic-memory/details`, data) } -// 关系演化 +// Relationship evolution export const getRelationshipEvolution = (data: { id: string; label: string; } ) => { return request.get(`/memory-storage/memory_space/relationship_evolution`, data) } -// 共同记忆时间线 +// Shared memory timeline export const getTimelineMemories = (data: { id: string; label: string; }) => { return request.get(`/memory-storage/memory_space/timeline_memories`, data) } @@ -207,72 +213,72 @@ export const getConversationDetail = (end_user: string, conversation_id: string) export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => { return request.post(`/memory/forget-memory/trigger`, data) } -/*************** end 用户记忆 相关接口 ******************************/ +/*************** end User Memory APIs ******************************/ -/****************** 记忆管理 相关接口 *******************************/ -// 记忆管理-获取所有配置 +/****************** Memory Management APIs *******************************/ +// Memory Management - Get all configurations export const memoryConfigListUrl = '/memory-storage/read_all_config' export const getMemoryConfigList = () => { return request.get(memoryConfigListUrl) } -// 记忆管理-创建配置 +// Memory Management - Create configuration export const createMemoryConfig = (values: MemoryFormData) => { return request.post('/memory-storage/create_config', values) } -// 记忆管理-更新配置 +// Memory Management - Update configuration export const updateMemoryConfig = (values: MemoryFormData) => { return request.post('/memory-storage/update_config', values) } -// 记忆管理-删除配置 +// Memory Management - Delete configuration export const deleteMemoryConfig = (config_id: number) => { return request.delete(`/memory-storage/delete_config?config_id=${config_id}`) } -// 遗忘引擎-获取配置 +// Forgetting Engine - Get configuration export const getMemoryForgetConfig = (config_id: number | string) => { return request.get('/memory/forget-memory/read_config', { config_id }) } -// 遗忘引擎-更新配置 +// Forgetting Engine - Update configuration export const updateMemoryForgetConfig = (values: ForgetConfigForm) => { return request.post('/memory/forget-memory/update_config', values) } -// 记忆萃取引擎-获取配置 +// Memory Extraction Engine - Get configuration export const getMemoryExtractionConfig = (config_id: number | string) => { return request.get('/memory-storage/read_config_extracted', { config_id: config_id }) } -// 记忆萃取引擎-更新配置 +// Memory Extraction Engine - Update configuration export const updateMemoryExtractionConfig = (values: ExtractionConfigForm) => { return request.post('/memory-storage/update_config_extracted', values) } -// 记忆萃取引擎-试运行 +// Memory Extraction Engine - Pilot run export const pilotRunMemoryExtractionConfig = (values: { config_id: number | string; dialogue_text: string; }, onMessage?: (data: SSEMessage[]) => void) => { return handleSSE('/memory-storage/pilot_run', values, onMessage) } -// 情绪引擎-获取配置 +// Emotion Engine - Get configuration export const getMemoryEmotionConfig = (config_id: number | string) => { return request.get('/memory/emotion/read_config', { config_id: config_id }) } -// 情绪引擎-更新配置 +// Emotion Engine - Update configuration export const updateMemoryEmotionConfig = (values: EmotionConfig) => { return request.post('/memory/emotion/updated_config', values) } -// 反思引擎-获取配置 +// Reflection Engine - Get configuration export const getMemoryReflectionConfig = (config_id: number | string) => { return request.get('/memory/reflection/configs', { config_id: config_id }) } -// 反思引擎-更新配置 +// Reflection Engine - Update configuration export const updateMemoryReflectionConfig = (values: SelfReflectionEngineConfig) => { return request.post('/memory/reflection/save', values) } -// 反思引擎-试运行 +// Reflection Engine - Pilot run export const pilotRunMemoryReflectionConfig = (values: { config_id: number | string; language_type: string; }) => { return request.get('/memory/reflection/run', values) } -/*************** end 记忆管理 相关接口 ******************************/ +/*************** end Memory Management APIs ******************************/ -/****************** API参数 相关接口 *******************************/ +/****************** API Parameters APIs *******************************/ export const getMemoryApi = () => { return request.get('/memory/docs/api') } -/*************** end API参数 相关接口 ******************************/ \ No newline at end of file +/*************** end API Parameters APIs ******************************/ \ No newline at end of file diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index eeee6bc9..2e88ad4a 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -658,8 +658,8 @@ export const zh = { priority: '结构化整合', addTool: '添加工具', tool: '工具', + variableConfig: '配置变量' }, - // 角色管理相关翻译 role: { roleManagement: '角色管理', roleId: '角色ID', diff --git a/web/src/store/menu.json b/web/src/store/menu.json index b49788a8..62f6c13c 100644 --- a/web/src/store/menu.json +++ b/web/src/store/menu.json @@ -332,21 +332,6 @@ } ] }, - { - "id": 19, - "parent": 0, - "code": "member", - "label": "成员管理", - "i18nKey": "menu.memberManagement", - "path": "/member", - "enable": true, - "display": true, - "level": 1, - "sort": 0, - "icon": null, - "iconActive": null, - "subs": null - }, { "id": 10, "parent": 0, @@ -377,6 +362,21 @@ "iconActive": null, "subs": null }, + { + "id": 19, + "parent": 0, + "code": "member", + "label": "成员管理", + "i18nKey": "menu.memberManagement", + "path": "/member", + "enable": true, + "display": true, + "level": 1, + "sort": 0, + "icon": null, + "iconActive": null, + "subs": null + }, { "id": 12, "parent": 0, diff --git a/web/src/utils/stream.ts b/web/src/utils/stream.ts index 7688cdd5..e4179e25 100644 --- a/web/src/utils/stream.ts +++ b/web/src/utils/stream.ts @@ -1,8 +1,47 @@ import { message } from 'antd'; import i18n from '@/i18n' import { cookieUtils } from './request' +import { refreshToken } from '@/api/user' +import { clearAuthData } from './auth' const API_PREFIX = '/api' +// Token refresh state +let isRefreshing = false; +let refreshPromise: Promise | null = null; + +// Refresh token function for SSE +const refreshTokenForSSE = async (): Promise => { + if (isRefreshing && refreshPromise) { + return refreshPromise; + } + + isRefreshing = true; + refreshPromise = (async () => { + try { + const refresh_token = cookieUtils.get('refreshToken'); + if (!refresh_token) { + throw new Error(i18n.t('common.refreshTokenNotExist')); + } + const response: any = await refreshToken(); + const newToken = response.access_token; + cookieUtils.set('authToken', newToken); + return newToken; + } catch (error) { + clearAuthData(); + message.warning(i18n.t('common.loginExpired')); + if (!window.location.hash.includes('#/login')) { + window.location.href = `/#/login`; + } + throw error; + } finally { + isRefreshing = false; + refreshPromise = null; + } + })(); + + return refreshPromise; +}; + export interface SSEMessage { event?: string data?: string | object @@ -66,62 +105,66 @@ function parseDataContent(dataContent: string): string | object { } } +const makeSSERequest = async (url: string, data: any, token: string, config = { headers: {} }) => { + return fetch(`${API_PREFIX}${url}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}`, + ...config.headers, + }, + body: JSON.stringify(data) + }); +}; export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMessage[]) => void, config = { headers: {} }) => { try { - const token = cookieUtils.get('authToken'); - const response = await fetch(`${API_PREFIX}${url}`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${token}`, - ...config.headers, - }, - body: JSON.stringify(data) - }); + let token = cookieUtils.get('authToken'); + let response = await makeSSERequest(url, data, token || '', config); - const { status } = response - - switch(status) { + switch (response.status) { case 401: if (url?.includes('/public')) { return message.warning(i18n.t('common.publicApiCannotRefreshToken')); } - window.location.href = `/#/login`; - break; - default: - if (!response.body) throw new Error('No response body'); - - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; // 添加缓冲区来处理不完整的消息 - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - const chunk = decoder.decode(value, { stream: true }); - buffer += chunk; - - // 处理完整的事件 - const events = buffer.split('\n\n'); - buffer = events.pop() || ''; // 保留最后一个可能不完整的事件 - - for (const event of events) { - if (event.trim() && onMessage) { - onMessage(parseSSEToJSON(event) ?? {}); - } - } - } - - // 处理剩余的缓冲区内容 - if (buffer.trim() && onMessage) { - onMessage(parseSSEToJSON(buffer) ?? {}); + try { + const newToken = await refreshTokenForSSE(); + response = await makeSSERequest(url, data, newToken, config); + } catch (refreshError) { + return; } break; } + if (!response.body) throw new Error('No response body'); + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; // 添加缓冲区来处理不完整的消息 + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + // 处理完整的事件 + const events = buffer.split('\n\n'); + buffer = events.pop() || ''; // 保留最后一个可能不完整的事件 + + for (const event of events) { + if (event.trim() && onMessage) { + onMessage(parseSSEToJSON(event) ?? {}); + } + } + } + + // 处理剩余的缓冲区内容 + if (buffer.trim() && onMessage) { + onMessage(parseSSEToJSON(buffer) ?? {}); + } } catch (error) { console.error('Request failed:', error); throw error; } -} \ No newline at end of file +}; \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 4d410338..28c96ec0 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -13,26 +13,25 @@ import type { Config, ModelConfig, AgentRef, - KnowledgeBase, - KnowledgeConfig, - Variable, MemoryConfig, AiPromptModalRef, Source, - ToolOption + ChatVariableConfigModalRef } from './types' +import type { Variable } from './components/VariableList/types' +import type { KnowledgeConfig } from './components/Knowledge/types' import type { Model } from '@/views/ModelManagement/types' import { getModelList } from '@/api/models'; import { saveAgentConfig } from '@/api/application' -import Knowledge from './components/Knowledge' -import VariableList from './components/VariableList' +import Knowledge from './components/Knowledge/Knowledge' +import VariableList from './components/VariableList/VariableList' import { getApplicationConfig } from '@/api/application' -import { getKnowledgeBaseList } from '@/api/knowledgeBase' import { memoryConfigListUrl } from '@/api/memory' import CustomSelect from '@/components/CustomSelect' import aiPrompt from '@/assets/images/application/aiPrompt.png' import AiPromptModal from './components/AiPromptModal' -import ToolList from './components/ToolList' +import ToolList from './components/ToolList/ToolList' +import ChatVariableConfigModal from './components/ChatVariableConfigModal'; const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className}) => { return ( @@ -66,7 +65,7 @@ const SwitchWrapper: FC<{ title: string, desc?: string, name: string | string[]; ) } -const SelectWrapper: FC<{ title: string, desc: string, name: string, url: string }> = ({ title, desc, name, url }) => { +const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], url: string }> = ({ title, desc, name, url }) => { const { t } = useTranslation(); return ( <> @@ -77,6 +76,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string, url: string className="rb:mb-0!" > ((_props, ref) => { const [modelList, setModelList] = useState([]) const [defaultModel, setDefaultModel] = useState(null) const [chatList, setChatList] = useState([]) - const [formData, setFormData] = useState<{ - default_model_config_id?: string, - model_parameters?: Config['model_parameters'], - tools: ToolOption[], - } | null>(null) - const values = Form.useWatch<{ - memoryEnabled: boolean; - memory_content?: string | number; - } & Config>([], form) - - const [knowledgeConfig, setKnowledgeConfig] = useState({ knowledge_bases: [] }) - const [variableList, setVariableList] = useState([]) + const values = Form.useWatch([], form) const [isSave, setIsSave] = useState(false) const initialized = useRef(false) - const [toolList, setToolList] = useState([]) // 初始化完成标记 useEffect(() => { - if (data && values && formData) { + if (data) { initialized.current = true } - }, [data, values, formData]) + }, [data]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [knowledgeConfig]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [variableList]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [formData]) useEffect(() => { if (!initialized.current) return if (isSave) return setIsSave(true) }, [values]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [toolList]) useEffect(() => { getModels() @@ -157,68 +125,19 @@ const Agent = forwardRef((_props, ref) => { setLoading(true) getApplicationConfig(id as string).then(res => { const response = res as Config - setData({ - ...response, - tools: Array.isArray(response.tools) ? response.tools : [] - }) - const { memory, tools } = response + let allTools = Array.isArray(response.tools) ? response.tools : [] form.setFieldsValue({ ...response, - memoryEnabled: memory?.enabled || false, - memory_content: memory?.memory_content ? Number(memory?.memory_content) : undefined, - tools: Array.isArray(tools) ? tools : [] + tools: allTools }) - setFormData({ - default_model_config_id: response.default_model_config_id, - model_parameters: response.model_parameters || {}, - tools: Array.isArray(tools) ? tools : [] + setData({ + ...response, + tools: allTools }) - if (response?.knowledge_retrieval?.knowledge_bases?.length) { - getDefaultKnowledgeList(response) - } - if (response?.tools?.length) { - setToolList(response?.tools) - } }).finally(() => { setLoading(false) }) } - const getDefaultKnowledgeList = (data: Config) => { - if (!data || !data.knowledge_retrieval || !data.knowledge_retrieval?.knowledge_bases?.length) { - return - } - const initialList = [...(data?.knowledge_retrieval?.knowledge_bases || [])] - getKnowledgeBaseList(undefined, { - kb_ids: initialList.map(vo => vo.kb_id).join(','), - page: 1, - pagesize: 100, - }) - .then(res => { - const list = res.items || [] - const knowledge_bases: KnowledgeBase[] = list.map(item => { - const filterItem = initialList.find(vo => vo.kb_id === item.id) - return { - ...item, - ...filterItem - } - }) - setKnowledgeConfig(prev => ({ - ...prev, - knowledge_bases: [...knowledge_bases] - })) - setData((prev) => { - prev = prev as Config - const knowledge_retrieval: KnowledgeConfig = { - ...(prev?.knowledge_retrieval || {}), - knowledge_bases: [...knowledge_bases] - } - return { - ...(prev || {}), - knowledge_retrieval - } - }) - }) - } const refresh = (vo: ModelConfig, type: Source) => { if (type === 'model') { @@ -227,15 +146,7 @@ const Agent = forwardRef((_props, ref) => { default_model_config_id, model_parameters: {...rest} }) - setFormData((prevState) => { - const prev = prevState as Config - return { - ...(prev || {}), - default_model_config_id, - model_parameters: {...rest} - }; - }) - if (default_model_config_id === formData?.default_model_config_id) { + if (default_model_config_id === values?.default_model_config_id) { setChatList([{ label: vo.label || '', model_config_id: default_model_config_id || '', @@ -279,24 +190,20 @@ const Agent = forwardRef((_props, ref) => { // 保存Agent配置 const handleSave = (flag = true) => { if (!isSave || !data) return Promise.resolve() - const { memoryEnabled, memory_content, ...rest } = values - const { knowledge_bases = [], ...knowledgeRest } = knowledgeConfig || {} - - + const { memory, knowledge_retrieval, tools, ...rest } = values + const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {} + const { memory_content } = memory || {} // 从原数据中获取memory的其他必要属性 const originalMemory = data.memory || ({} as MemoryConfig) const params: Config = { ...data, ...rest, - ...(formData || {}), memory: { ...originalMemory, - enabled: memoryEnabled, + ...memory, memory_content: memory_content ? String(memory_content) : '', - max_history: originalMemory.max_history || '', }, - variables: variableList || [], knowledge_retrieval: knowledge_bases.length > 0 ? { ...data.knowledge_retrieval, ...knowledgeRest, @@ -305,7 +212,7 @@ const Agent = forwardRef((_props, ref) => { ...(item.config || {}) })) } as KnowledgeConfig : null, - tools: toolList.map(vo => ({ + tools: tools.map(vo => ({ tool_id: vo.tool_id, operation: vo.operation, enabled: vo.enabled @@ -336,8 +243,8 @@ const Agent = forwardRef((_props, ref) => { modelConfigModalRef.current?.handleOpen('chat') } useEffect(() => { - if (formData?.default_model_config_id && modelList.length > 0) { - const filterValue = modelList.find(item => item.id === formData.default_model_config_id) + if (values?.default_model_config_id && modelList.length > 0) { + const filterValue = modelList.find(item => item.id === values.default_model_config_id) setDefaultModel(filterValue as Model | null) setChatList([{ label: filterValue?.name || '', @@ -346,7 +253,7 @@ const Agent = forwardRef((_props, ref) => { list: [] }]) } - }, [modelList, formData?.default_model_config_id]) + }, [modelList, values?.default_model_config_id]) useImperativeHandle(ref, () => ({ handleSave @@ -358,8 +265,31 @@ const Agent = forwardRef((_props, ref) => { } const updatePrompt = (value: string) => { form.setFieldValue('system_prompt', value) + const variables = value.match(/\{\{([^}]+)\}\}/g)?.map(match => match.slice(2, -2)) || [] + const uniqueVariables = [...new Set(variables)] + const newVariableList: Variable[] = uniqueVariables.map((name, index) => ({ + index, + type: 'text', + name, + display_name: name, + required: false + })) + updateVariableList(newVariableList) } + const updateVariableList = (list: Variable[]) => { + form.setFieldValue('variables', [...list]) + setChatVariables([...list]) + } + const chatVariableConfigModalRef = useRef(null) + const [chatVariables, setChatVariables] = useState([]) + const handleOpenVariableConfig = () => { + chatVariableConfigModalRef.current?.handleOpen(chatVariables) + } + const handleSaveChatVariable = (values: Variable[]) => { + setChatVariables(values) + } + console.log('values', values) return ( <> {loading && } @@ -377,8 +307,9 @@ const Agent = forwardRef((_props, ref) => {
+ + - {/* 提示词 */}
@@ -404,36 +335,31 @@ const Agent = forwardRef((_props, ref) => { - {/* 知识库 */} - + + + {/* 记忆配置 */} - + - {/* 变量配置 */} - + + + {/* 工具配置 */} - + + + @@ -442,6 +368,9 @@ const Agent = forwardRef((_props, ref) => { {t('application.debuggingAndPreview')} + @@ -461,7 +390,7 @@ const Agent = forwardRef((_props, ref) => { @@ -470,6 +399,10 @@ const Agent = forwardRef((_props, ref) => { defaultModel={defaultModel} refresh={updatePrompt} /> + ); }); diff --git a/web/src/views/ApplicationConfig/components/Card.tsx b/web/src/views/ApplicationConfig/components/Card.tsx index 7d9328ea..f414848f 100644 --- a/web/src/views/ApplicationConfig/components/Card.tsx +++ b/web/src/views/ApplicationConfig/components/Card.tsx @@ -3,18 +3,21 @@ import RbCard from '@/components/RbCard/Card' interface CardProps { title?: string | ReactNode; + subTitle?: string | ReactNode; children: ReactNode; extra?: ReactNode; } const Card: FC = ({ title, + subTitle, children, extra, }) => { return ( void; +} + +const ChatVariableConfigModal = forwardRef(({ + refresh, +}, ref) => { + const { t } = useTranslation(); + const [visible, setVisible] = useState(false); + const [form] = Form.useForm<{variables: Variable[]}>(); + const [loading, setLoading] = useState(false) + const [initialValues, setInitialValues] = useState([]) + + // 封装取消方法,添加关闭弹窗逻辑 + const handleClose = () => { + setVisible(false); + form.resetFields(); + setLoading(false) + }; + + const handleOpen = (values: Variable[]) => { + console.log('values', values) + setVisible(true); + form.setFieldsValue({variables: values}) + setInitialValues([...values]) + }; + // 封装保存方法,添加提交逻辑 + const handleSave = () => { + form.validateFields().then((values) => { + refresh([ + ...(values?.variables ?? []), + ]) + handleClose() + }) + } + + // 暴露给父组件的方法 + useImperativeHandle(ref, () => ({ + handleOpen, + handleClose + })); + + console.log(form.getFieldValue('variables')) + + return ( + +
+ + {(fields) => ( + <> + {fields.map(({ name }, index) => { + const field = initialValues[index] + return ( + + { + field.type === 'text' && + } + { + field.type === 'number' && form.setFieldValue(['variables', name, 'value'], value)} /> + } + { + field.type === 'paragraph' && + } + + ) + })} + + )} + +
+
+ ); +}); + +export default ChatVariableConfigModal; \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/Knowledge.tsx b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx similarity index 54% rename from web/src/views/ApplicationConfig/components/Knowledge.tsx rename to web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx index bc1207e4..1e59f26d 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx @@ -2,7 +2,6 @@ import { type FC, useRef, useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Space, Button, List } from 'antd' import knowledgeEmpty from '@/assets/images/application/knowledgeEmpty.svg' -import Card from './Card' import type { KnowledgeConfigForm, KnowledgeConfig, @@ -11,14 +10,16 @@ import type { KnowledgeModalRef, KnowledgeConfigModalRef, KnowledgeGlobalConfigModalRef, -} from '../types' +} from './types' import Empty from '@/components/Empty' import KnowledgeListModal from './KnowledgeListModal' import KnowledgeConfigModal from './KnowledgeConfigModal' import KnowledgeGlobalConfigModal from './KnowledgeGlobalConfigModal' import Tag from '@/components/Tag' +import { getKnowledgeBaseList } from '@/api/knowledgeBase' +import Card from '../Card' -const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) => void}> = ({data, onUpdate}) => { +const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfig) => void}> = ({value = {knowledge_bases: []}, onChange}) => { const { t } = useTranslation() const knowledgeModalRef = useRef(null) const knowledgeConfigModalRef = useRef(null) @@ -27,12 +28,31 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) const [editConfig, setEditConfig] = useState({} as KnowledgeConfig) useEffect(() => { - if (data) { - setEditConfig({ ...(data || {}) }) - const knowledge_bases = [...(data.knowledge_bases || [])] - setKnowledgeList(knowledge_bases) + if (value && JSON.stringify(value) !== JSON.stringify(editConfig)) { + setEditConfig({ ...(value || {}) }) + const knowledge_bases = [...(value.knowledge_bases || [])] + + // 检查是否有knowledge_bases缺少name字段 + const basesWithoutName = knowledge_bases.filter(base => !base.name) + if (basesWithoutName.length > 0) { + // 调用接口获取完整的知识库信息 + getKnowledgeBaseList().then(res => { + const fullBases = knowledge_bases.map(base => { + if (!base.name) { + const fullBase = res.items.find((item: any) => item.id === base.kb_id) + return fullBase ? { ...base, ...fullBase } : base + } + return base + }) + setKnowledgeList(fullBases) + }).catch(() => { + setKnowledgeList(knowledge_bases) + }) + } else { + setKnowledgeList(knowledge_bases) + } } - }, [data]) + }, [value]) const handleKnowledgeConfig = () => { knowledgeGlobalConfigModalRef.current?.handleOpen() @@ -43,7 +63,7 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) const handleDeleteKnowledge = (id: string) => { const list = knowledgeList.filter(item => item.id !== id) setKnowledgeList([...list]) - onUpdate({ + onChange && onChange({ ...editConfig, knowledge_bases: [...list], }) @@ -65,7 +85,7 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) list = [...values as KnowledgeBase[]] } setKnowledgeList([...list]) - onUpdate({ + onChange && onChange({ ...editConfig, knowledge_bases: [...list], }) @@ -77,14 +97,14 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) config: {...values as KnowledgeConfigForm} } setKnowledgeList([...list]) - onUpdate({ + onChange && onChange({ ...editConfig, knowledge_bases: [...list], }) } else if (type === 'rerankerConfig') { const rerankerValues = values as RerankerConfig setEditConfig(prev => ({ ...prev, ...rerankerValues })) - onUpdate({ + onChange && onChange({ ...editConfig, ...rerankerValues, reranker_id: rerankerValues.rerank_model ? rerankerValues.reranker_id : undefined, @@ -93,55 +113,54 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) } } return ( - handleKnowledgeConfig()}>{t('application.globalConfig')} + + + + } > -
-
{t('application.associatedKnowledgeBase')}
- -
- {knowledgeList.length === 0 ? : ( - -
-
- {item.name} - - {item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')} - -
{t('application.contains', {include_count: item.doc_num})}
+ renderItem={(item) => { + if (!item.id) return null + return ( + +
+
+ {item.name} + + {item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')} + +
{t('application.contains', {include_count: item.doc_num})}
+
+ +
handleEditKnowledge(item)} + >
+
handleDeleteKnowledge(item.id)} + >
+
- -
handleEditKnowledge(item)} - >
-
handleDeleteKnowledge(item.id)} - >
-
-
- - )} + + ) + }} /> } - {/* 全局设置 */} - {/* 知识库列表 */} void; } -const retrieveTypes = ['participle', 'semantic', 'hybrid'] +const retrieveTypes: RetrieveType[] = ['participle', 'semantic', 'hybrid'] const KnowledgeConfigModal = forwardRef(({ refresh, @@ -33,8 +33,11 @@ const KnowledgeConfigModal = forwardRef { form.setFieldsValue({ - retrieve_type: retrieveTypes[0], + retrieve_type: data?.config?.retrieve_type || retrieveTypes[0], kb_id: data.id, + top_k: data?.config?.top_k || 5, + similarity_threshold: data?.config?.similarity_threshold || 0.5, + vector_similarity_weight: data?.config?.vector_similarity_weight || 0.5, ...(data || {}), ...(data?.config || {}), }) @@ -62,12 +65,10 @@ const KnowledgeConfigModal = forwardRef { if (values?.retrieve_type) { - const initialValues = Object.keys(values).map(key => { - return { - [key as keyof KnowledgeConfigForm]: (key === 'kb_id' || key === 'retrieve_type') ? values[key] : undefined - } - }) - form.resetFields(initialValues) + const fieldsToReset = Object.keys(values).filter(key => + key !== 'kb_id' && key !== 'retrieve_type' + ) as (keyof KnowledgeConfigForm)[]; + form.resetFields(fieldsToReset); } }, [values?.retrieve_type]) @@ -84,12 +85,12 @@ const KnowledgeConfigModal = forwardRef {data && ( -
-
+
+
{data.name} -
{t('application.contains', {include_count: data.doc_num})}
+
{t('application.contains', {include_count: data.doc_num})}
-
{formatDateTime(data.updated_at, 'YYYY-MM-DD HH:mm:ss')}
+
{formatDateTime(data.updated_at, 'YYYY-MM-DD HH:mm:ss')}
)} {/* 语义相似度阈值 similarity_threshold */} {values?.retrieve_type === 'semantic' && ( @@ -123,6 +130,7 @@ const KnowledgeConfigModal = forwardRef -
{t('application.globalConfigDesc')}
+
{t('application.globalConfigDesc')}
{/* 结果重排 */} -
-
+
+
{t('application.rerankModel')} -
{t('application.rerankModelDesc')}
+
{t('application.rerankModelDesc')}
@@ -110,7 +110,12 @@ const KnowledgeGlobalConfigModal = forwardRef - + form.setFieldValue('reranker_top_k', value)} + /> } diff --git a/web/src/views/ApplicationConfig/components/KnowledgeListModal.tsx b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeListModal.tsx similarity index 88% rename from web/src/views/ApplicationConfig/components/KnowledgeListModal.tsx rename to web/src/views/ApplicationConfig/components/Knowledge/KnowledgeListModal.tsx index 0c7b47b2..f1ebd516 100644 --- a/web/src/views/ApplicationConfig/components/KnowledgeListModal.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeListModal.tsx @@ -2,7 +2,7 @@ import { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; import { Space, List } from 'antd'; import { useTranslation } from 'react-i18next'; import clsx from 'clsx' -import type { KnowledgeModalRef, KnowledgeBase } from '../types' +import type { KnowledgeModalRef, KnowledgeBase } from './types' import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types' import RbModal from '@/components/RbModal' import { getKnowledgeBaseList } from '@/api/knowledgeBase' @@ -39,12 +39,13 @@ const KnowledgeListModal = forwardRef(({ setQuery({}) setSelectedIds([]) setSelectedRows([]) - getList() }; useEffect(() => { - getList() - }, [query.keywords]) + if (visible) { + getList() + } + }, [query.keywords, visible]) const getList = () => { getKnowledgeBaseList(undefined, { ...query, @@ -124,15 +125,15 @@ const KnowledgeListModal = forwardRef(({ dataSource={filterList} renderItem={(item: KnowledgeBase) => ( -
handleSelect(item)}> -
+
{item.name} -
{t('application.contains', {include_count: item.doc_num})}
+
{t('application.contains', {include_count: item.doc_num})}
-
{formatDateTime(item.created_at, 'YYYY-MM-DD HH:mm:ss')}
+
{formatDateTime(item.created_at, 'YYYY-MM-DD HH:mm:ss')}
)} diff --git a/web/src/views/ApplicationConfig/components/Knowledge/types.ts b/web/src/views/ApplicationConfig/components/Knowledge/types.ts new file mode 100644 index 00000000..f4f9ed17 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Knowledge/types.ts @@ -0,0 +1,30 @@ +import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types' +export interface RerankerConfig { + rerank_model?: boolean | undefined; + reranker_id?: string | undefined; + reranker_top_k?: number | undefined; +} +export type RetrieveType = 'participle' | 'semantic' | 'hybrid' +export interface KnowledgeConfigForm { + kb_id?: string; + similarity_threshold?: number; + vector_similarity_weight?: number; + top_k?: number; + retrieve_type?: RetrieveType; +} +export interface KnowledgeBase extends KnowledgeBaseListItem, KnowledgeConfigForm { + config?: KnowledgeConfigForm +} +export interface KnowledgeConfig extends RerankerConfig { + knowledge_bases: KnowledgeBase[]; +} + +export interface KnowledgeConfigModalRef { + handleOpen: (data: KnowledgeBase) => void; +} +export interface KnowledgeGlobalConfigModalRef { + handleOpen: () => void; +} +export interface KnowledgeModalRef { + handleOpen: (config?: KnowledgeConfig[]) => void; +} \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/ToolList.tsx b/web/src/views/ApplicationConfig/components/ToolList/ToolList.tsx similarity index 93% rename from web/src/views/ApplicationConfig/components/ToolList.tsx rename to web/src/views/ApplicationConfig/components/ToolList/ToolList.tsx index fde7286b..e914d879 100644 --- a/web/src/views/ApplicationConfig/components/ToolList.tsx +++ b/web/src/views/ApplicationConfig/components/ToolList/ToolList.tsx @@ -1,22 +1,22 @@ import { type FC, useRef, useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Space, Button, List, Switch } from 'antd' -import Card from './Card' +import Card from '../Card' import type { ToolModalRef, ToolOption -} from '../types' +} from './types' import Empty from '@/components/Empty' import ToolModal from './ToolModal' import { getToolMethods, getToolDetail } from '@/api/tools' -const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => void}> = ({data, onUpdate}) => { +const ToolList: FC<{ value?: ToolOption[]; onChange?: (config: ToolOption[]) => void}> = ({value, onChange}) => { const { t } = useTranslation() const toolModalRef = useRef(null) const [toolList, setToolList] = useState([]) useEffect(() => { - if (data) { - const processedData = data.map(async (item) => { + if (value) { + const processedData = value.map(async (item) => { if (!item.label && item.tool_id) { try { const [toolDetail, methods] = await Promise.all([ @@ -77,7 +77,7 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi Promise.all(processedData).then(setToolList) } - }, [data]) + }, [value]) const handleAddTool = () => { toolModalRef.current?.handleOpen() @@ -85,12 +85,12 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi const updateTools = (tool: ToolOption) => { const list = [...toolList, tool] setToolList(list) - onUpdate(list) + onChange && onChange(list) } const handleDeleteTool = (index: number) => { const list = toolList.filter((_item, idx) => idx !== index) setToolList([...list]) - onUpdate(list) + onChange && onChange(list) } const handleChangeEnabled = (index: number) => { const list = toolList.map((item, idx) => { @@ -103,7 +103,7 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi return item }) setToolList([...list]) - onUpdate(list) + onChange && onChange(list) } return ( voi } > - {toolList.length === 0 ? : diff --git a/web/src/views/ApplicationConfig/components/ToolModal.tsx b/web/src/views/ApplicationConfig/components/ToolList/ToolModal.tsx similarity index 100% rename from web/src/views/ApplicationConfig/components/ToolModal.tsx rename to web/src/views/ApplicationConfig/components/ToolList/ToolModal.tsx diff --git a/web/src/views/ApplicationConfig/components/ToolList/types.ts b/web/src/views/ApplicationConfig/components/ToolList/types.ts new file mode 100644 index 00000000..142ffe26 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/ToolList/types.ts @@ -0,0 +1,26 @@ +export interface ToolOption { + value?: string | number | null; + label?: React.ReactNode; + description?: string; + children?: ToolOption[]; + isLeaf?: boolean; + method_id?: string; + operation?: string; + parameters?: Parameter[]; + tool_id?: string; + enabled?: boolean; +} +export interface Parameter { + name: string; + type: string; + description: string; + required: boolean; + default: any; + enum: null | string[]; + minimum: number; + maximum: number; + pattern: null | string; +} +export interface ToolModalRef { + handleOpen: () => void; +} \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/VariableList.tsx b/web/src/views/ApplicationConfig/components/VariableList.tsx deleted file mode 100644 index fbadf2ea..00000000 --- a/web/src/views/ApplicationConfig/components/VariableList.tsx +++ /dev/null @@ -1,131 +0,0 @@ -import { type FC, useRef, useState, useEffect } from 'react' -import { useTranslation } from 'react-i18next' -import { Space, Button, Switch } from 'antd' -import variablesEmpty from '@/assets/images/application/variablesEmpty.svg' -import Card from './Card' -import Table from '@/components/Table'; -import type { Variable, VariableEditModalRef } from '../types' -import Empty from '@/components/Empty' -import VariableEditModal from './VariableEditModal' - -interface VariableListProps { - data?: Variable[]; - onUpdate: (data: Variable[]) => void; -} -const VariableList: FC = ({data = [], onUpdate}) => { - const { t } = useTranslation() - const variableEditModalRef = useRef(null) - const [variableList, setVariableList] = useState([]) - const [maxIndex, setMaxIndex] = useState(0) - - useEffect(() => { - if (!data || data.length === 0) return - const list = data.map((item, index) => ({ - ...item, - index - })) - setVariableList(list) - onUpdate(list) - setMaxIndex(list.length) - }, [data]) - - const handleAddVariable = () => { - variableEditModalRef.current?.handleOpen() - } - const handleSaveVariable = (value: Variable) => { - if (value.index !== undefined && value.index >= 0) { - const index = variableList.findIndex(item => item.index === value.index) - if (index !== -1) { - const newData = [...variableList] - newData[index] = value - setVariableList([...newData]) - onUpdate([...newData]) - } - } else { - const list = [...variableList, { - index: maxIndex + 1, - ...value - }] - setVariableList(list) - onUpdate([...list]) - setMaxIndex(maxIndex + 1) - } - } - const handleDeleteVariable = (index: number) => { - const list = variableList.filter((_, i) => i !== index) - setVariableList(list) - onUpdate([...list]) - } - return ( - -
-
- {t('application.VariableManagement')} - ({t('application.VariableManagementDesc')}) -
- -
- - {/* List */} - {variableList.length > 0 - ? ( -
- t(`application.${type}`) - }, - { - title: t('application.variableKey'), - dataIndex: 'name', - key: 'name', - }, - { - title: t('application.variableName'), - dataIndex: 'display_name', - key: 'display_name', - }, - { - title: t('application.optional'), - dataIndex: 'required', - key: 'required', - render: (required) => - }, - { - title: t('common.operation'), - key: 'action', - render: (_, record, index: number) => ( - - - - - ), - }, - ]} - initialData={variableList as unknown as Record[]} - emptySize={88} - /> - - ) - : - } - - - ) -} -export default VariableList \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/ApiExtensionModal.tsx b/web/src/views/ApplicationConfig/components/VariableList/ApiExtensionModal.tsx similarity index 99% rename from web/src/views/ApplicationConfig/components/ApiExtensionModal.tsx rename to web/src/views/ApplicationConfig/components/VariableList/ApiExtensionModal.tsx index b1c7450a..4f4f9047 100644 --- a/web/src/views/ApplicationConfig/components/ApiExtensionModal.tsx +++ b/web/src/views/ApplicationConfig/components/VariableList/ApiExtensionModal.tsx @@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Input } from 'antd'; import { useTranslation } from 'react-i18next'; -import type { ApiExtensionModalData, ApiExtensionModalRef } from '../types' +import type { ApiExtensionModalData, ApiExtensionModalRef } from './types' import RbModal from '@/components/RbModal' const FormItem = Form.Item; diff --git a/web/src/views/ApplicationConfig/components/VariableEditModal.tsx b/web/src/views/ApplicationConfig/components/VariableList/VariableEditModal.tsx similarity index 96% rename from web/src/views/ApplicationConfig/components/VariableEditModal.tsx rename to web/src/views/ApplicationConfig/components/VariableList/VariableEditModal.tsx index 3efd721c..69e213fb 100644 --- a/web/src/views/ApplicationConfig/components/VariableEditModal.tsx +++ b/web/src/views/ApplicationConfig/components/VariableList/VariableEditModal.tsx @@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState, useRef } from 'react'; import { Form, Input, Select, InputNumber, Checkbox, Tag, Divider, Button } from 'antd'; import { useTranslation } from 'react-i18next'; -import type { ApiExtensionModalRef, Variable, VariableEditModalRef } from '../types' +import type { ApiExtensionModalRef, Variable, VariableEditModalRef } from './types' import RbModal from '@/components/RbModal' import SortableList from '@/components/SortableList' import ApiExtensionModal from './ApiExtensionModal' @@ -137,7 +137,14 @@ const VariableEditModal = forwardRef - + { + if (!form.getFieldValue('display_name')) { + form.setFieldValue('display_name', e.target.value) + } + }} + /> {/* 显示名称 */} void; +} +const VariableList: FC = ({value = [], onChange}) => { + const { t } = useTranslation() + const variableEditModalRef = useRef(null) + + const handleAddVariable = () => { + variableEditModalRef.current?.handleOpen() + } + const handleSaveVariable = (variable: Variable) => { + const newList = [...(value || [])] + if (variable.index !== undefined && variable.index >= 0) { + const index = newList.findIndex(item => item.index === variable.index) + if (index !== -1) { + newList[index] = variable + } + } else { + newList.push({ ...variable, index: Date.now() }) + } + onChange?.(newList) + } + return ( + + {t('application.variableConfiguration')} + ({t('application.VariableManagementDesc')}) + } + extra={} + > + + {(fields, { remove }) => { + return ( + <> + {fields.length > 0 ? ( +
+
t(`application.${type}`) + }, + { + title: t('application.variableKey'), + dataIndex: 'name', + key: 'name', + }, + { + title: t('application.variableName'), + dataIndex: 'display_name', + key: 'display_name', + }, + { + title: t('application.optional'), + dataIndex: 'required', + key: 'required', + render: (required) => + }, + { + title: t('common.operation'), + key: 'action', + render: (_, record, index: number) => ( + + + + + ), + }, + ]} + initialData={value as unknown as Record[]} + emptySize={88} + /> + + ) : ( + + )} + + ) + }} + + + + ) +} +export default VariableList \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/VariableList/types.ts b/web/src/views/ApplicationConfig/components/VariableList/types.ts new file mode 100644 index 00000000..f262dda1 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/VariableList/types.ts @@ -0,0 +1,28 @@ +export interface Variable { + index?: number; + name: string; + display_name: string; + type: string; + required: boolean; + max_length?: number; + description?: string; + + key?: string; + default_value?: string; + options?: string[]; + api_extension?: string; + hidden?: boolean; + value?: any; +} +export interface VariableEditModalRef { + handleOpen: (values?: Variable) => void; +} + +export interface ApiExtensionModalData { + name: string; + apiEndpoint: string; + apiKey: string; +} +export interface ApiExtensionModalRef { + handleOpen: () => void; +} \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/types.ts b/web/src/views/ApplicationConfig/types.ts index 6eb97f22..6f641ebb 100644 --- a/web/src/views/ApplicationConfig/types.ts +++ b/web/src/views/ApplicationConfig/types.ts @@ -1,4 +1,6 @@ -import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types' +import type { KnowledgeConfig } from './components/Knowledge/types' +import type { Variable } from './components/VariableList/types' +import type { ToolOption } from './components/ToolList/types' import type { ChatItem } from '@/components/Chat/types' import type { GraphRef } from '@/views/Workflow/types'; import type { ApiKey } from '@/views/ApiKeyManagement/types' @@ -14,55 +16,6 @@ export interface ModelConfig { n: number; stop?: string; } - -/*************** 知识库相关 ******************/ -export interface RerankerConfig { - rerank_model?: boolean | undefined; - reranker_id?: string | undefined; - reranker_top_k?: number | undefined; -} -export interface KnowledgeConfigForm { - kb_id?: string; - similarity_threshold?: number; - vector_similarity_weight?: number; - top_k?: number; - retrieve_type?: 'participle' | 'semantic' | 'hybrid'; -} -export interface KnowledgeBase extends KnowledgeBaseListItem, KnowledgeConfigForm { - config?: KnowledgeConfigForm -} -export interface KnowledgeConfig extends RerankerConfig { - knowledge_bases: KnowledgeBase[]; -} - -export interface KnowledgeConfigModalRef { - handleOpen: (data: KnowledgeBase) => void; -} -export interface KnowledgeGlobalConfigModalRef { - handleOpen: () => void; -} -/*********** end 知识库相关 ******************/ - -/*************** 变量相关 ******************/ -export interface Variable { - index?: number; - name: string; - display_name: string; - type: string; - required: boolean; - max_length?: number; - description?: string; - - key: string; - default_value?: string; - options?: string[]; - api_extension?: string; - hidden?: boolean; -} -export interface VariableEditModalRef { - handleOpen: (values?: Variable) => void; -} -/*********** end 变量相关 ******************/ export interface MemoryConfig { enabled: boolean; memory_content?: string; @@ -131,17 +84,6 @@ export interface ModelConfigModalData { export interface AiPromptModalRef { handleOpen: () => void; } -export interface KnowledgeModalRef { - handleOpen: (config?: KnowledgeConfig[]) => void; -} -export interface ApiExtensionModalData { - name: string; - apiEndpoint: string; - apiKey: string; -} -export interface ApiExtensionModalRef { - handleOpen: () => void; -} export interface ChatData { label?: string; model_config_id?: string; @@ -206,30 +148,6 @@ export interface AiPromptForm { message?: string; current_prompt?: string; } -export interface ToolModalRef { - handleOpen: () => void; -} - -export interface ToolOption { - value?: string | number | null; - label?: React.ReactNode; - description?: string; - children?: ToolOption[]; - isLeaf?: boolean; - method_id?: string; - operation?: string; - parameters?: Parameter[]; - tool_id?: string; - enabled?: boolean; -} -export interface Parameter { - name: string; - type: string; - description: string; - required: boolean; - default: any; - enum: null | string[]; - minimum: number; - maximum: number; - pattern: null | string; +export interface ChatVariableConfigModalRef { + handleOpen: (values: Variable[]) => void; } \ No newline at end of file diff --git a/web/src/views/Index/components/VersionCard.tsx b/web/src/views/Index/components/VersionCard.tsx index 5cd94348..b299ad29 100644 --- a/web/src/views/Index/components/VersionCard.tsx +++ b/web/src/views/Index/components/VersionCard.tsx @@ -4,7 +4,7 @@ * @Author: yujiangping * @Date: 2026-01-12 16:34:59 * @LastEditors: yujiangping - * @LastEditTime: 2026-01-16 13:00:22 + * @LastEditTime: 2026-01-16 15:38:35 */ import React, { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -44,7 +44,7 @@ const GuideCard: React.FC = () => { {versionInfo?.version} -
+
{versionInfo && (() => { const introduction = getIntroduction(); return introduction ? (<> diff --git a/web/src/views/UserMemoryDetail/components/EmotionTags.tsx b/web/src/views/UserMemoryDetail/components/EmotionTags.tsx index 5fc8f382..a6e6385a 100644 --- a/web/src/views/UserMemoryDetail/components/EmotionTags.tsx +++ b/web/src/views/UserMemoryDetail/components/EmotionTags.tsx @@ -121,7 +121,7 @@ const EmotionTags: FC = () => { })}
- : + : } ) diff --git a/web/src/views/UserMemoryDetail/components/Habits.tsx b/web/src/views/UserMemoryDetail/components/Habits.tsx index 93ef817f..76756f9f 100644 --- a/web/src/views/UserMemoryDetail/components/Habits.tsx +++ b/web/src/views/UserMemoryDetail/components/Habits.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState } from 'react' +import { useEffect, useState, forwardRef, useImperativeHandle } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' import { Skeleton, Space, Progress } from 'antd'; @@ -20,7 +20,7 @@ interface HabitsItem { specific_examples: string[]; } -const Habits: FC = () => { +const Habits = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => { const { t } = useTranslation() const { id } = useParams() const [loading, setLoading] = useState(false) @@ -43,6 +43,9 @@ const Habits: FC = () => { setLoading(false) }) } + useImperativeHandle(ref, () => ({ + handleRefresh: getData + })); return ( <> @@ -80,5 +83,5 @@ const Habits: FC = () => { ) -} +}) export default Habits \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/components/InterestAreas.tsx b/web/src/views/UserMemoryDetail/components/InterestAreas.tsx index 6f74d255..df1a75c6 100644 --- a/web/src/views/UserMemoryDetail/components/InterestAreas.tsx +++ b/web/src/views/UserMemoryDetail/components/InterestAreas.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState } from 'react' +import { useEffect, useState, forwardRef, useImperativeHandle } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' import { Skeleton, Progress } from 'antd'; @@ -23,7 +23,7 @@ interface InterestAreasItem { art: Item; } -const InterestAreas: FC = () => { +const InterestAreas = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => { const { t } = useTranslation() const { id } = useParams() const [loading, setLoading] = useState(false) @@ -47,6 +47,9 @@ const InterestAreas: FC = () => { }) } + useImperativeHandle(ref, () => ({ + handleRefresh: getData + })); return ( { } ) -} +}) export default InterestAreas \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/components/Portrait.tsx b/web/src/views/UserMemoryDetail/components/Portrait.tsx index 3164ae06..a14fee0d 100644 --- a/web/src/views/UserMemoryDetail/components/Portrait.tsx +++ b/web/src/views/UserMemoryDetail/components/Portrait.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState } from 'react' +import { useEffect, useState, forwardRef, useImperativeHandle } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' import { Skeleton, Progress } from 'antd'; @@ -25,7 +25,7 @@ interface PortraitItem { literature: Item; } -const Portrait: FC = () => { +const Portrait = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => { const { t } = useTranslation() const { id } = useParams() const [loading, setLoading] = useState(false) @@ -49,6 +49,9 @@ const Portrait: FC = () => { }) } + useImperativeHandle(ref, () => ({ + handleRefresh: getData + })); return ( { } ) -} +}) export default Portrait \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/components/Preferences.tsx b/web/src/views/UserMemoryDetail/components/Preferences.tsx index 0644197f..4b8d1766 100644 --- a/web/src/views/UserMemoryDetail/components/Preferences.tsx +++ b/web/src/views/UserMemoryDetail/components/Preferences.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState, useRef, useMemo } from 'react' +import { useEffect, useState, useRef, useMemo, forwardRef, useImperativeHandle } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' import { Row, Col, Skeleton } from 'antd' @@ -31,7 +31,7 @@ const generateCategoryColors = (categories: string[]) => { return colors } -const Preferences: FC = () => { +const Preferences = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => { const { t } = useTranslation() const { id } = useParams() const chartRef = useRef(null) @@ -138,6 +138,9 @@ const Preferences: FC = () => { return selectedWord !== null && data[selectedWord].tag_name ? <>{data[selectedWord].tag_name}{t('implicitDetail.preferencesDetail')} : '' }, [selectedWord, data, t]) + useImperativeHandle(ref, () => ({ + handleRefresh: getData + })); return ( <>
{t('forgetDetail.overviewTitle')}
@@ -184,6 +187,6 @@ const Preferences: FC = () => { ) -} +}) export default Preferences \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/components/Suggestions.tsx b/web/src/views/UserMemoryDetail/components/Suggestions.tsx index 35fde91f..c2c8ca8b 100644 --- a/web/src/views/UserMemoryDetail/components/Suggestions.tsx +++ b/web/src/views/UserMemoryDetail/components/Suggestions.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState } from 'react' +import { useEffect, useState, forwardRef, useImperativeHandle } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' @@ -18,7 +18,7 @@ interface Suggestions { actionable_steps: string[]; }>; } -const Suggestions: FC = () => { +const Suggestions = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => { const { t } = useTranslation() const { id } = useParams() const [suggestions, setSuggestions] = useState(null) @@ -37,6 +37,9 @@ const Suggestions: FC = () => { }) } + useImperativeHandle(ref, () => ({ + handleRefresh: getSuggestionData + })); return ( { } ) -} +}) export default Suggestions \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx b/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx index ef23463a..d79407da 100644 --- a/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx @@ -1,34 +1,57 @@ -import { type FC } from 'react' +import { forwardRef, useImperativeHandle, useRef } from 'react' import { useTranslation } from 'react-i18next' import { Row, Col } from 'antd' +import { useParams } from 'react-router-dom' import Preferences from '../components/Preferences' import Portrait from '../components/Portrait' import InterestAreas from '../components/InterestAreas' import Habits from '../components/Habits' +import { + generateProfile, +} from '@/api/memory' -const ImplicitDetail: FC = () => { +const ImplicitDetail = forwardRef<{ handleRefresh: () => void; }>((_props, ref) => { const { t } = useTranslation() + const { id } = useParams() + const preferencesRef = useRef<{ handleRefresh: () => void; }>(null) + const portraitRef = useRef<{ handleRefresh: () => void; }>(null) + const interestAreasRef = useRef<{ handleRefresh: () => void; }>(null) + const habitsRef = useRef<{ handleRefresh: () => void; }>(null) + + const handleRefresh = () => { + if (!id) return + generateProfile(id) + .then(() => { + preferencesRef.current?.handleRefresh() + portraitRef.current?.handleRefresh() + interestAreasRef.current?.handleRefresh() + habitsRef.current?.handleRefresh() + }) + } + useImperativeHandle(ref, () => ({ + handleRefresh + })); return (
{t('implicitDetail.title')}
- +
{t('implicitDetail.portraitTitle')}
{t('implicitDetail.portraitSubTitle')}
- + - + - + ) -} +}) export default ImplicitDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx b/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx index e6ddfd20..6515263e 100644 --- a/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx @@ -1,13 +1,27 @@ -import { type FC } from 'react' +import { forwardRef, useImperativeHandle, useRef } from 'react' import { Row, Col, Space } from 'antd'; +import { useParams } from 'react-router-dom' import WordCloud from '../components/WordCloud' import EmotionTags from '../components/EmotionTags' import Health from '../components/Health' import Suggestions from '../components/Suggestions' +import { generateSuggestions } from '@/api/memory' -const StatementDetail: FC = () => { +const StatementDetail = forwardRef((_props, ref) => { + const { id } = useParams() + const suggestionsRef = useRef<{ handleRefresh: () => void; }>(null) + const handleRefresh = () => { + if (!id) return + generateSuggestions(id) + .then(() => { + suggestionsRef.current?.handleRefresh() + }) + } + useImperativeHandle(ref, () => ({ + handleRefresh + })); return ( @@ -18,10 +32,10 @@ const StatementDetail: FC = () => { - + ) -} +}) export default StatementDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/index.tsx b/web/src/views/UserMemoryDetail/pages/index.tsx index f225b1f0..16004edc 100644 --- a/web/src/views/UserMemoryDetail/pages/index.tsx +++ b/web/src/views/UserMemoryDetail/pages/index.tsx @@ -24,6 +24,8 @@ const Detail: FC = () => { const navigate = useNavigate() const [name, setName] = useState('') const forgetDetailRef = useRef<{ handleRefresh: () => void }>(null) + const statementDetailRef = useRef<{ handleRefresh: () => void }>(null) + const implicitDetailRef = useRef<{ handleRefresh: () => void }>(null) useEffect(() => { if (!id) return @@ -45,7 +47,17 @@ const Detail: FC = () => { navigate(`/user-memory/detail/${id}/${key}`, { replace: true }) } const handleRefresh = () => { - forgetDetailRef.current?.handleRefresh() + switch(type) { + case 'FORGET_MEMORY': + forgetDetailRef.current?.handleRefresh() + break; + case 'EMOTIONAL_MEMORY': + statementDetailRef.current?.handleRefresh() + break + case 'IMPLICIT_MEMORY': + implicitDetailRef.current?.handleRefresh() + break + } } if (type === 'GRAPH') { @@ -67,16 +79,16 @@ const Detail: FC = () => { } - extra={type === 'FORGET_MEMORY' && + extra={['FORGET_MEMORY', 'EMOTIONAL_MEMORY', 'IMPLICIT_MEMORY'].includes(type as string) && } />
- {type === 'EMOTIONAL_MEMORY' && } + {type === 'EMOTIONAL_MEMORY' && } {type === 'FORGET_MEMORY' && } - {type === 'IMPLICIT_MEMORY' && } + {type === 'IMPLICIT_MEMORY' && } {type === 'SHORT_TERM_MEMORY' && } {type === 'PERCEPTUAL_MEMORY' && } {type === 'EPISODIC_MEMORY' && } diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index a9dc39c3..f8a5a6bc 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -586,6 +586,77 @@ export const useWorkflowGraph = ({ graphRef.current.resize(containerRef.current.offsetWidth, containerRef.current.offsetHeight); } }; + + const nodeChangePosition = ({ node, options }: { node: Node; options: { skipParentHandler?: boolean } }) => { + const embedPadding = 50; // Define the embed padding constant + if (options.skipParentHandler) { + return + } + + const children = node.getChildren() + if (children && children.length) { + node.prop('originPosition', node.getPosition()) + } + + const parent = node.getParent() + if (parent && parent.isNode()) { + let originSize = parent.prop('originSize') + if (originSize == null) { + originSize = parent.getSize() + parent.prop('originSize', originSize) + } + + let originPosition = parent.prop('originPosition') + if (originPosition == null) { + originPosition = parent.getPosition() + parent.prop('originPosition', originPosition) + } + + let x = originPosition.x + let y = originPosition.y + let cornerX = originPosition.x + originSize.width + let cornerY = originPosition.y + originSize.height + let hasChange = false + + const children = parent.getChildren() + if (children) { + children.forEach((child) => { + const bbox = child.getBBox().inflate(embedPadding) + const corner = bbox.getCorner() + + if (bbox.x < x) { + x = bbox.x + hasChange = true + } + + if (bbox.y < y) { + y = bbox.y + hasChange = true + } + + if (corner.x > cornerX) { + cornerX = corner.x + hasChange = true + } + + if (corner.y > cornerY) { + cornerY = corner.y + hasChange = true + } + }) + } + + if (hasChange) { + parent.prop( + { + position: { x, y }, + size: { width: cornerX - x, height: cornerY - y }, + }, + { skipParentHandler: true }, + ) + } + } + } // 初始化 const init = () => { @@ -674,10 +745,7 @@ export const useWorkflowGraph = ({ }, }, embedding: { - enabled: true, - validate (this) { - return false - } + enabled: false, }, translating: { restrict(view) { @@ -693,6 +761,17 @@ export const useWorkflowGraph = ({ return null }, }, + highlighting: { + embedding: { + name: 'stroke', + args: { + padding: -1, + attrs: { + stroke: '#73d13d', + }, + }, + }, + }, }); // 使用插件 setupPlugins();