diff --git a/api/app/base/type.py b/api/app/base/type.py new file mode 100644 index 00000000..fecbe13a --- /dev/null +++ b/api/app/base/type.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel, Field +from sqlalchemy import TypeDecorator, JSON + + +class PydanticType(TypeDecorator): + impl = JSON + + def __init__(self, pydantic_model: type[BaseModel]): + super().__init__() + self.model = pydantic_model + + def process_bind_param(self, value, dialect): + # 入库:Model -> dict + if value is None: + return None + if isinstance(value, self.model): + return value.dict() + return value # 已经是 dict 也放行 + + def process_result_value(self, value, dialect): + # 出库:dict -> Model + if value is None: + return None + # return self.model.parse_obj(value) # pydantic v1 + return self.model.model_validate(value) # pydantic v2 diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 0b07d0c9..a45c701f 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -4,39 +4,45 @@ 认证方式: JWT Token """ from fastapi import APIRouter + from . 
import ( - model_controller, - task_controller, - test_controller, - user_controller, - auth_controller, - workspace_controller, - setup_controller, - file_controller, - document_controller, - knowledge_controller, - chunk_controller, - knowledgeshare_controller, + api_key_controller, app_controller, - upload_controller, + auth_controller, + chunk_controller, + document_controller, + emotion_config_controller, + emotion_controller, + file_controller, + home_page_controller, + implicit_memory_controller, + knowledge_controller, + knowledgeshare_controller, memory_agent_controller, memory_dashboard_controller, - memory_storage_controller, - memory_dashboard_controller, + memory_forget_controller, memory_reflection_controller, - api_key_controller, - release_share_controller, - public_share_controller, + memory_short_term_controller, + memory_storage_controller, + model_controller, multi_agent_controller, - workflow_controller, - emotion_controller, - emotion_config_controller, prompt_optimizer_controller, + public_share_controller, + release_share_controller, + setup_controller, + task_controller, + test_controller, tool_controller, + upload_controller, + user_controller, + user_memory_controllers, + workflow_controller, + workspace_controller, memory_forget_controller, home_page_controller, + memory_perceptual_controller, + memory_working_controller, ) -from . 
import user_memory_controllers # 创建管理端 API 路由器 manager_router = APIRouter() @@ -71,8 +77,12 @@ manager_router.include_router(emotion_controller.router) manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(memory_reflection_controller.router) +manager_router.include_router(memory_short_term_controller.router) manager_router.include_router(tool_controller.router) manager_router.include_router(memory_forget_controller.router) manager_router.include_router(home_page_controller.router) +manager_router.include_router(implicit_memory_controller.router) +manager_router.include_router(memory_perceptual_controller.router) +manager_router.include_router(memory_working_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 698f061d..2300f148 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -48,6 +48,7 @@ def list_apps( include_shared: bool = True, page: int = 1, pagesize: int = 10, + ids: Optional[str] = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -55,8 +56,19 @@ def list_apps( - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 + - 当提供 ids 参数时,按逗号分割获取指定应用,不分页 """ workspace_id = current_user.current_workspace_id + service = app_service.AppService(db) + + # 当 ids 存在且不为 None 时,根据 ids 获取应用 + if ids is not None: + app_ids = [id.strip() for id in ids.split(',') if id.strip()] + items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) + items = [service._convert_to_schema(app, workspace_id) for app in items_orm] + return success(data=items) + + # 正常分页查询 items_orm, total = app_service.list_apps( db, workspace_id=workspace_id, @@ -69,8 +81,6 @@ def list_apps( pagesize=pagesize, ) - # 使用 AppService 的转换方法来设置 is_shared 字段 - service = app_service.AppService(db) items = 
[service._convert_to_schema(app, workspace_id) for app in items_orm] meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) return success(data=PageData(page=meta, items=items)) @@ -506,7 +516,7 @@ async def draft_run( multi_agent_request = MultiAgentRunRequest( message=payload.message, conversation_id=payload.conversation_id, - user_id=payload.user_id, + user_id=payload.user_id or str(current_user.id), variables=payload.variables or {}, use_llm_routing=True # 默认启用 LLM 路由 ) @@ -728,9 +738,23 @@ async def draft_run_compare( from app.core.exceptions import ResourceNotFoundException raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) + # 获取 agent_cfg.model_parameters,如果是 ModelParameters 对象则转为字典 + agent_model_params = agent_cfg.model_parameters + if hasattr(agent_model_params, 'model_dump'): + agent_model_params = agent_model_params.model_dump() + elif not isinstance(agent_model_params, dict): + agent_model_params = {} + + # 获取 model_item.model_parameters,如果是 ModelParameters 对象则转为字典 + item_model_params = model_item.model_parameters + if hasattr(item_model_params, 'model_dump'): + item_model_params = item_model_params.model_dump() + elif not isinstance(item_model_params, dict): + item_model_params = {} + merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) + **(agent_model_params or {}), + **(item_model_params or {}) } model_configs.append({ diff --git a/api/app/controllers/home_page_controller.py b/api/app/controllers/home_page_controller.py index 6665eec1..77db9d8f 100644 --- a/api/app/controllers/home_page_controller.py +++ b/api/app/controllers/home_page_controller.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session +from app.core.config import settings from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user @@ -26,4 +27,12 @@ def get_workspace_list( ): 
"""获取工作空间列表""" workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id) - return success(data=workspace_list, msg="工作空间列表获取成功") \ No newline at end of file + return success(data=workspace_list, msg="工作空间列表获取成功") + +@router.get("/version", response_model=ApiResponse) +def get_system_version(): + """获取系统版本号+说明""" + return success(data={ + "version": settings.SYSTEM_VERSION, + "introduction": settings.SYSTEM_INTRODUCTION + }, msg="系统版本获取成功") \ No newline at end of file diff --git a/api/app/controllers/implicit_memory_controller.py b/api/app/controllers/implicit_memory_controller.py new file mode 100644 index 00000000..6ef39929 --- /dev/null +++ b/api/app/controllers/implicit_memory_controller.py @@ -0,0 +1,312 @@ +from datetime import datetime +from typing import Optional + +from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import fail, success +from app.db import get_db +from app.dependencies import ( + cur_workspace_access_guard, + get_current_user, +) +from app.models.user_model import User +from app.schemas.response_schema import ApiResponse +from app.services.implicit_memory_service import ImplicitMemoryService +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session + +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/implicit-memory", + tags=["Implicit Memory"], +) + + +def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict: + """ + Centralized error handling for implicit memory operations. 
+ + Args: + e: The exception that occurred + operation: Description of the operation that failed + user_id: Optional user ID for logging context + + Returns: + Standardized error response + """ + error_context = f"user_id={user_id}" if user_id else "unknown user" + + if isinstance(e, ValueError): + if "user" in str(e).lower() and "not found" in str(e).lower(): + api_logger.warning(f"Invalid user ID for {operation}: {error_context}") + return fail(BizCode.INVALID_USER_ID, "无效的用户ID", str(e)) + elif "insufficient" in str(e).lower() or "no data" in str(e).lower(): + api_logger.warning(f"Insufficient data for {operation}: {error_context}") + return fail(BizCode.INSUFFICIENT_DATA, "数据不足,无法进行分析", str(e)) + else: + api_logger.warning(f"Invalid parameters for {operation}: {error_context}") + return fail(BizCode.INVALID_FILTER_PARAMS, "无效的参数", str(e)) + + elif isinstance(e, KeyError): + api_logger.warning(f"Missing required data for {operation}: {error_context}") + return fail(BizCode.INSUFFICIENT_DATA, "缺少必要的数据", str(e)) + + elif isinstance(e, (ConnectionError, TimeoutError)): + api_logger.error(f"Service unavailable for {operation}: {error_context}") + return fail(BizCode.SERVICE_UNAVAILABLE, "服务暂时不可用", str(e)) + + elif "analysis" in str(e).lower() or "llm" in str(e).lower(): + api_logger.error(f"Analysis failed for {operation}: {error_context}", exc_info=True) + return fail(BizCode.ANALYSIS_FAILED, "分析处理失败", str(e)) + + elif "storage" in str(e).lower() or "database" in str(e).lower(): + api_logger.error(f"Storage error for {operation}: {error_context}", exc_info=True) + return fail(BizCode.PROFILE_STORAGE_ERROR, "数据存储失败", str(e)) + + else: + api_logger.error(f"Unexpected error for {operation}: {error_context}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, f"{operation}失败", str(e)) + + +def validate_user_id(user_id: str) -> None: + """ + Validate user ID format and constraints. 
+ + Args: + user_id: User ID to validate + + Raises: + ValueError: If user ID is invalid + """ + if not user_id or not user_id.strip(): + raise ValueError("User ID cannot be empty") + + if len(user_id.strip()) < 1: + raise ValueError("User ID is too short") + + +def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None: + """ + Validate date range parameters. + + Args: + start_date: Start date + end_date: End date + + Raises: + ValueError: If date range is invalid + """ + if (start_date and not end_date) or (end_date and not start_date): + raise ValueError("Both start_date and end_date must be provided together") + + if start_date and end_date and start_date >= end_date: + raise ValueError("start_date must be before end_date") + + if start_date and start_date > datetime.now(): + raise ValueError("start_date cannot be in the future") + + +def validate_confidence_threshold(threshold: float) -> None: + """ + Validate confidence threshold parameter. + + Args: + threshold: Confidence threshold to validate + + Raises: + ValueError: If threshold is invalid + """ + if not 0.0 <= threshold <= 1.0: + raise ValueError("confidence_threshold must be between 0.0 and 1.0") + + +@router.get("/preferences/{user_id}", response_model=ApiResponse) +@cur_workspace_access_guard() +async def get_preference_tags( + user_id: str, + confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"), + tag_category: Optional[str] = Query(None, description="Filter by tag category"), + start_date: Optional[datetime] = Query(None, description="Filter start date"), + end_date: Optional[datetime] = Query(None, description="Filter end date"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +) -> ApiResponse: + """ + Get user preference tags with filtering options. 
+ + Args: + user_id: Target user ID + confidence_threshold: Minimum confidence score (0.0-1.0) + tag_category: Optional category filter + start_date: Optional start date filter + end_date: Optional end date filter + + Returns: + List of preference tags matching the filters + """ + api_logger.info(f"Preference tags requested for user: {user_id}") + + try: + # Validate inputs + validate_user_id(user_id) + validate_confidence_threshold(confidence_threshold) + validate_date_range(start_date, end_date) + + # Create service with user-specific config + service = ImplicitMemoryService(db=db, end_user_id=user_id) + + # Build date range + date_range = None + if start_date and end_date: + from app.schemas.implicit_memory_schema import DateRange + date_range = DateRange(start_date=start_date, end_date=end_date) + + # Get preference tags + tags = await service.get_preference_tags( + user_id=user_id, + confidence_threshold=confidence_threshold, + tag_category=tag_category, + date_range=date_range + ) + + api_logger.info(f"Retrieved {len(tags)} preference tags for user: {user_id}") + return success(data=[tag.model_dump(mode='json') for tag in tags], msg="偏好标签获取成功") + + except Exception as e: + return handle_implicit_memory_error(e, "偏好标签获取", user_id) + + +@router.get("/portrait/{user_id}", response_model=ApiResponse) +@cur_workspace_access_guard() +async def get_dimension_portrait( + user_id: str, + include_history: bool = Query(False, description="Include historical trends"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +) -> ApiResponse: + """ + Get user's four-dimension personality portrait. 
+ + Args: + user_id: Target user ID + include_history: Whether to include historical trend data + + Returns: + Four-dimension personality portrait with scores and evidence + """ + api_logger.info(f"Dimension portrait requested for user: {user_id}") + + try: + # Validate inputs + validate_user_id(user_id) + + # Create service with user-specific config + service = ImplicitMemoryService(db=db, end_user_id=user_id) + + portrait = await service.get_dimension_portrait( + user_id=user_id, + include_history=include_history + ) + + api_logger.info(f"Dimension portrait retrieved for user: {user_id}") + return success(data=portrait.model_dump(mode='json'), msg="四维画像获取成功") + + except Exception as e: + return handle_implicit_memory_error(e, "四维画像获取", user_id) + + +@router.get("/interest-areas/{user_id}", response_model=ApiResponse) +@cur_workspace_access_guard() +async def get_interest_area_distribution( + user_id: str, + include_trends: bool = Query(False, description="Include trend analysis"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +) -> ApiResponse: + """ + Get user's interest area distribution across four areas. 
+ + Args: + user_id: Target user ID + include_trends: Whether to include trend analysis data + + Returns: + Interest area distribution with percentages and evidence + """ + api_logger.info(f"Interest area distribution requested for user: {user_id}") + + try: + # Validate inputs + validate_user_id(user_id) + + # Create service with user-specific config + service = ImplicitMemoryService(db=db, end_user_id=user_id) + + distribution = await service.get_interest_area_distribution( + user_id=user_id, + include_trends=include_trends + ) + + api_logger.info(f"Interest area distribution retrieved for user: {user_id}") + return success(data=distribution.model_dump(mode='json'), msg="兴趣领域分布获取成功") + + except Exception as e: + return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id) + + +@router.get("/habits/{user_id}", response_model=ApiResponse) +@cur_workspace_access_guard() +async def get_behavior_habits( + user_id: str, + confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"), + frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"), + time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +) -> ApiResponse: + """ + Get user's behavioral habits with filtering options. 
+ + Args: + user_id: Target user ID + confidence_level: Filter by confidence level (high, medium, low) + frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered) + time_period: Filter by time period (current, past) + + Returns: + List of behavioral habits matching the filters + """ + api_logger.info(f"Behavior habits requested for user: {user_id}") + + try: + # Validate inputs + validate_user_id(user_id) + + # Convert string confidence level to numerical + numerical_confidence = None + if confidence_level: + confidence_mapping = { + "high": 85, + "medium": 50, + "low": 20 + } + numerical_confidence = confidence_mapping.get(confidence_level.lower()) + + # Create service with user-specific config + service = ImplicitMemoryService(db=db, end_user_id=user_id) + + habits = await service.get_behavior_habits( + user_id=user_id, + confidence_level=numerical_confidence, + frequency_pattern=frequency_pattern, + time_period=time_period + ) + + api_logger.info(f"Retrieved {len(habits)} behavior habits for user: {user_id}") + return success(data=[habit.model_dump(mode='json') for habit in habits], msg="行为习惯获取成功") + + except Exception as e: + return handle_implicit_memory_error(e, "行为习惯获取", user_id) + + diff --git a/api/app/controllers/memory_forget_controller.py b/api/app/controllers/memory_forget_controller.py index d4a76f6f..705445fd 100644 --- a/api/app/controllers/memory_forget_controller.py +++ b/api/app/controllers/memory_forget_controller.py @@ -76,9 +76,28 @@ async def trigger_forgetting_cycle( api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + # 通过 group_id 获取关联的 config_id + try: + from app.services.memory_agent_service import get_end_user_connected_config + + connected_config = get_end_user_connected_config(payload.group_id, db) + config_id = connected_config.get("memory_config_id") + + if config_id is None: + 
api_logger.warning(f"终端用户 {payload.group_id} 未关联记忆配置") + return fail(BizCode.INVALID_PARAMETER, f"终端用户 {payload.group_id} 未关联记忆配置", "memory_config_id is None") + + api_logger.debug(f"通过 group_id={payload.group_id} 获取到 config_id={config_id}") + except ValueError as e: + api_logger.warning(f"获取终端用户配置失败: {str(e)}") + return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") + except Exception as e: + api_logger.error(f"获取终端用户配置时发生错误: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e)) + api_logger.info( f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: " - f"group_id={payload.group_id}, max_batch={payload.max_merge_batch_size}, " + f"group_id={payload.group_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, " f"min_days={payload.min_days_since_access}" ) @@ -89,7 +108,7 @@ async def trigger_forgetting_cycle( group_id=payload.group_id, max_merge_batch_size=payload.max_merge_batch_size, min_days_since_access=payload.min_days_since_access, - config_id=payload.config_id + config_id=config_id ) # 构建响应 @@ -217,7 +236,6 @@ async def update_forgetting_config( @router.get("/stats", response_model=ApiResponse) async def get_forgetting_stats( group_id: Optional[str] = None, - config_id: Optional[int] = None, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): @@ -227,8 +245,7 @@ async def get_forgetting_stats( 返回知识层节点统计、激活值分布等信息。 Args: - group_id: 组ID(可选) - config_id: 配置ID(可选,用于获取遗忘阈值) + group_id: 组ID(即 end_user_id,可选) current_user: 当前用户 db: 数据库会话 @@ -242,6 +259,27 @@ async def get_forgetting_stats( api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + # 如果提供了 group_id,通过它获取 config_id + config_id = None + if group_id: + try: + from app.services.memory_agent_service import get_end_user_connected_config + + connected_config = get_end_user_connected_config(group_id, db) + config_id = 
connected_config.get("memory_config_id") + + if config_id is None: + api_logger.warning(f"终端用户 {group_id} 未关联记忆配置") + return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None") + + api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}") + except ValueError as e: + api_logger.warning(f"获取终端用户配置失败: {str(e)}") + return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") + except Exception as e: + api_logger.error(f"获取终端用户配置时发生错误: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e)) + api_logger.info( f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: " f"group_id={group_id}, config_id={config_id}" diff --git a/api/app/controllers/memory_perceptual_controller.py b/api/app/controllers/memory_perceptual_controller.py new file mode 100644 index 00000000..5154c763 --- /dev/null +++ b/api/app/controllers/memory_perceptual_controller.py @@ -0,0 +1,255 @@ +import uuid +from typing import Optional + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import success, fail +from app.db import get_db +from app.dependencies import get_current_user +from app.models import User +from app.models.memory_perceptual_model import PerceptualType +from app.schemas.memory_perceptual_schema import ( + PerceptualQuerySchema, + PerceptualFilter +) +from app.schemas.response_schema import ApiResponse +from app.services.memory_perceptual_service import MemoryPerceptualService + +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/perceptual", + tags=["Perceptual Memory System"], + dependencies=[Depends(get_current_user)] +) + + +@router.get("/{group_id}/count", response_model=ApiResponse) +def get_memory_count( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve perceptual 
memory statistics for a user group. + + Args: + group_id: ID of the user group (usually end_user_id in this context) + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Response containing memory count statistics + """ + api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}") + + try: + service = MemoryPerceptualService(db) + count_stats = service.get_memory_count(group_id) + + api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}") + + return success( + data=count_stats, + msg="Memory statistics retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch memory statistics", + ) + + +@router.get("/{group_id}/last_visual", response_model=ApiResponse) +def get_last_visual_memory( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve the most recent VISION-type memory for a user. 
+ + Args: + group_id: ID of the user group + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Metadata of the latest visual memory + """ + api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}") + + try: + service = MemoryPerceptualService(db) + visual_memory = service.get_latest_visual_memory(group_id) + + if visual_memory is None: + api_logger.info(f"No visual memory found: group_id={group_id}") + return success( + data=None, + msg="No visual memory available" + ) + + api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}") + + return success( + data=visual_memory, + msg="Latest visual memory retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch latest visual memory", + ) + + +@router.get("/{group_id}/last_listen", response_model=ApiResponse) +def get_last_memory_listen( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve the most recent AUDIO-type memory for a user. 
+ + Args: + group_id: ID of the user group + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Metadata of the latest audio memory + """ + api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}") + + try: + service = MemoryPerceptualService(db) + audio_memory = service.get_latest_audio_memory(group_id) + + if audio_memory is None: + api_logger.info(f"No audio memory found: group_id={group_id}") + return success( + data=None, + msg="No audio memory available" + ) + + api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}") + + return success( + data=audio_memory, + msg="Latest audio memory retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch latest audio memory", + ) + + +@router.get("/{group_id}/last_text", response_model=ApiResponse) +def get_last_text_memory( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve the most recent TEXT-type memory for a user. 
+ + Args: + group_id: ID of the user group + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Metadata of the latest text memory + """ + api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}") + + try: + # 调用服务层获取最近的文本记忆 + service = MemoryPerceptualService(db) + text_memory = service.get_latest_text_memory(group_id) + + if text_memory is None: + api_logger.info(f"No text memory found: group_id={group_id}") + return success( + data=None, + msg="No text memory available" + ) + + api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}") + + return success( + data=text_memory, + msg="Latest text memory retrieved successfully" + ) + + except Exception as e: + api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}") + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch latest text memory", + ) + + +@router.get("/{group_id}/timeline", response_model=ApiResponse) +def get_memory_time_line( + group_id: uuid.UUID, + perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"), + page: int = Query(1, ge=1, description="页码"), + page_size: int = Query(10, ge=1, le=100, description="每页大小"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve a timeline of perceptual memories for a user group. 
+ + Args: + group_id: ID of the user group + perceptual_type: Optional filter for perceptual type + page: Page number for pagination + page_size: Number of items per page + current_user: Current authenticated user + db: Database session + + Returns: + ApiResponse: Timeline data of perceptual memories + """ + api_logger.info( + f"Fetching perceptual memory timeline: user={current_user.username}, " + f"group_id={group_id}, type={perceptual_type}, page={page}" + ) + + try: + query = PerceptualQuerySchema( + filter=PerceptualFilter(type=perceptual_type), + page=page, + page_size=page_size + ) + + service = MemoryPerceptualService(db) + timeline_data = service.get_time_line(group_id, query) + + api_logger.info( + f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, " + f"returned={len(timeline_data.memories)}" + ) + + return success( + data=timeline_data.model_dump(), + msg="Perceptual memory timeline retrieved successfully" + ) + + except Exception as e: + api_logger.error( + f"Failed to fetch perceptual memory timeline: group_id={group_id}, " + f"error={str(e)}" + ) + return fail( + code=BizCode.INTERNAL_ERROR, + msg="Failed to fetch perceptual memory timeline", + ) diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py new file mode 100644 index 00000000..64991f4d --- /dev/null +++ b/api/app/controllers/memory_short_term_controller.py @@ -0,0 +1,43 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from app.core.logging_config import get_api_logger +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User + +from app.services.memory_storage_service import search_entity +from app.services.memory_short_service import ShortService,LongService +from dotenv import load_dotenv +from sqlalchemy.orm import Session +from typing import Optional +load_dotenv() +api_logger 
= get_api_logger() + +router = APIRouter( + prefix="/memory/short", + tags=["Memory"], +) +@router.get("/short_term") +async def short_term_configs( + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + # 获取短期记忆数据 + short_term=ShortService(end_user_id) + short_result=short_term.get_short_databasets() + short_count=short_term.get_short_count() + + long_term=LongService(end_user_id) + long_result=long_term.get_long_databasets() + + entity_result = await search_entity(end_user_id) + result = { + 'short_term': short_result, + 'long_term': long_result, + 'entity': entity_result.get('num', 0), + "retrieval_number":short_count, + "long_term_number":len(long_result) + } + + return success(data=result, msg="短期记忆系统数据获取成功") \ No newline at end of file diff --git a/api/app/controllers/memory_working_controller.py b/api/app/controllers/memory_working_controller.py new file mode 100644 index 00000000..dfd64044 --- /dev/null +++ b/api/app/controllers/memory_working_controller.py @@ -0,0 +1,134 @@ +import uuid + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from app.core.logging_config import get_api_logger +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user +from app.models import User +from app.schemas.response_schema import ApiResponse +from app.services.conversation_service import ConversationService + +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/work", + tags=["Working Memory System"], + dependencies=[Depends(get_current_user)] +) + + +@router.get("/{group_id}/count", response_model=ApiResponse) +def get_memory_count( + group_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + pass + + +@router.get("/{group_id}/conversations", response_model=ApiResponse) +def get_conversations( + group_id: uuid.UUID, + current_user: User = 
Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + Retrieve all conversations for the current user in a specific group. + + Args: + group_id (UUID): The group identifier. + current_user (User, optional): The authenticated user. + db (Session, optional): SQLAlchemy session. + + Returns: + ApiResponse: Contains a list of conversation IDs. + + Notes: + - Initializes the ConversationService with the current DB session. + - Returns only conversation IDs for lightweight response. + - Logs can be added to trace requests in production. + """ + conversation_service = ConversationService(db) + conversations = conversation_service.get_user_conversations( + group_id + ) + return success(data=[ + { + "id": conversation.id, + "title": conversation.title + } for conversation in conversations + ], msg="get conversations success") + + +@router.get("/{group_id}/messages", response_model=ApiResponse) +def get_messages( + conversation_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + Retrieve the message history for a specific conversation. + + Args: + conversation_id (UUID): The ID of the conversation to fetch messages from. + current_user (User, optional): The authenticated user. + db (Session, optional): SQLAlchemy session. + + Returns: + ApiResponse: Contains the list of messages in the conversation. + + Notes: + - Uses ConversationService to fetch messages. + - Consider paginating results if message history is large. + - Logging can be added for audit and debugging. 
+ """ + conversation_service = ConversationService(db) + messages_obj = conversation_service.get_messages( + conversation_id, + ) + messages = [ + { + "role": message.role, + "content": message.content, + "created_at": int(message.created_at.timestamp() * 1000), + } + for message in messages_obj + ] + return success(data=messages, msg="get conversation history success") + + +@router.get("/{group_id}/detail", response_model=ApiResponse) +async def get_conversation_detail( + conversation_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + Retrieve detailed information about a specific conversation. + + This endpoint will fetch the conversation detail for the user. If the detail + does not exist or is outdated, it will trigger the LLM to generate a new summary. + + Args: + conversation_id (UUID): The ID of the conversation. + current_user (User, optional): The authenticated user making the request. + db (Session, optional): SQLAlchemy session. + + Returns: + ApiResponse: Contains the conversation detail serialized as a dictionary. + + Notes: + - Uses async ConversationService to fetch or generate the conversation detail. + - Handles workspace and user-specific context automatically. + - Logging and exception handling should be implemented for production monitoring. 
+ """ + conversation_service = ConversationService(db) + detail = await conversation_service.get_conversation_detail( + user=current_user, + conversation_id=conversation_id, + workspace_id=current_user.current_workspace_id + ) + return success(data=detail.model_dump(), msg="get conversation detail success") diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index 2069dd66..dba52d0b 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -108,16 +108,23 @@ async def get_prompt_opt( service = PromptOptimizerService(db) async def event_generator(): - async for chunk in service.optimize_prompt( - tenant_id=current_user.tenant_id, - model_id=data.model_id, - session_id=session_id, - user_id=current_user.id, - current_prompt=data.current_prompt, - user_require=data.message - ): - # chunk 是 prompt 的增量内容 - yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n" + yield "event:start\ndata: {}\n\n" + try: + async for chunk in service.optimize_prompt( + tenant_id=current_user.tenant_id, + model_id=data.model_id, + session_id=session_id, + user_id=current_user.id, + current_prompt=data.current_prompt, + user_require=data.message + ): + # chunk 是 prompt 的增量内容 + yield f"event:message\ndata: {json.dumps(chunk)}\n\n" + except Exception as e: + yield f"event:error\ndata: {json.dumps( + {"error": str(e)} + )}\n\n" + yield "event:end\ndata: {}\n\n" return StreamingResponse( event_generator(), diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index a7a6203d..02c73718 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -1,4 +1,5 @@ import hashlib +import json import uuid from typing import Annotated from fastapi import APIRouter, Depends, Query, Request @@ -18,7 +19,7 @@ from app.services.conversation_service import 
ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService from app.services.app_chat_service import AppChatService, get_app_chat_service -from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release +from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release router = APIRouter(prefix="/public/share", tags=["Public Share"]) logger = get_business_logger() @@ -288,7 +289,7 @@ async def chat( password = None # Token 认证不需要密码 # end_user_id = user_id other_id = user_id - + # 提前验证和准备(在流式响应开始前完成) # 这样可以确保错误能正确返回,而不是在流式响应中间出错 from app.models.app_model import AppType @@ -364,6 +365,9 @@ async def chat( config = release.config or {} if not config.get("sub_agents"): raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING) + elif app_type == AppType.WORKFLOW: + # Multi-Agent 类型:验证多 Agent 配置 + pass else: raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) @@ -392,6 +396,8 @@ async def chat( if app_type == AppType.AGENT: # 流式返回 + agent_config = agent_config_4_app_release(release) + if payload.stream: # async def event_generator(): # async for event in service.chat_stream( @@ -424,10 +430,11 @@ async def chat( user_id= str(new_end_user.id), # 转换为字符串 variables=payload.variables, web_search=payload.web_search, - config=payload.agent_config, + config=agent_config, memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ): yield event @@ -463,10 +470,12 @@ async def chat( web_search=payload.web_search, memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + 
workspace_id=workspace_id ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: + # config = workflow_config_4_app_release(release) config = multi_agent_config_4_app_release(release) if payload.stream: async def event_generator(): @@ -479,8 +488,8 @@ async def chat( config=config, web_search=payload.web_search, memory=payload.memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event @@ -551,8 +560,71 @@ async def chat( # ) # return success(data=conversation_schema.ChatResponse(**result)) + elif app_type == AppType.WORKFLOW: + + config = workflow_config_4_app_release(release) + if payload.stream: + async def event_generator(): + async for event in app_chat_service.workflow_chat_stream( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=new_end_user.id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + app_id=release.app_id, + workspace_id=workspace_id + ): + event_type = event.get("event", "message") + event_data = event.get("data", {}) + + # 转换为标准 SSE 格式(字符串) + sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" + yield sse_message + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + + # 多 Agent 非流式返回 + result = await app_chat_service.workflow_chat( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=new_end_user.id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + app_id=release.app_id, + 
workspace_id=workspace_id + ) + logger.debug( + "工作流试运行返回结果", + extra={ + "result_type": str(type(result)), + "has_response": "response" in result if isinstance(result, dict) else False + } + ) + return success( + data=result, + msg="工作流任务执行成功" + ) + # return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) + else: from app.core.exceptions import BusinessException from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) - pass diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 5a78a28b..583b4700 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -1,4 +1,5 @@ """App 服务接口 - 基于 API Key 认证""" +import json from typing import Annotated from fastapi import APIRouter, Depends, Request, Body @@ -21,7 +22,7 @@ from app.schemas.api_key_schema import ApiKeyAuth from app.services import workspace_service from app.services.app_chat_service import AppChatService, get_app_chat_service from app.services.conversation_service import ConversationService, get_conversation_service -from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release +from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release from app.services.app_service import get_app_service, AppService router = APIRouter(prefix="/app", tags=["V1 - App API"]) @@ -153,7 +154,8 @@ async def chat( config=agent_config, memory=memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ): yield event @@ -177,7 +179,8 @@ async def chat( web_search=web_search, memory=memory, storage_type=storage_type, - 
user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: @@ -194,8 +197,8 @@ async def chat( config=config, web_search=web_search, memory=memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event @@ -211,7 +214,6 @@ async def chat( # 多 Agent 非流式返回 result = await app_chat_service.multi_agent_chat( - message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID user_id=end_user_id, # 转换为字符串 @@ -226,22 +228,29 @@ async def chat( return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.WORKFLOW: # 多 Agent 流式返回 - config = dict_to_workflow_config(app.current_release.config,app.id) + config = workflow_config_4_app_release(app.current_release) if payload.stream: async def event_generator(): async for event in app_chat_service.workflow_chat_stream( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=end_user_id, # 转换为字符串 + user_id=new_end_user.id, # 转换为字符串 variables=payload.variables, config=config, - web_search=web_search, - memory=memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + app_id=app.app_id, + workspace_id=workspace_id ): - yield event + event_type = event.get("event", "message") + event_data = event.get("data", {}) + + # 转换为标准 SSE 格式(字符串) + sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" + yield sse_message return StreamingResponse( event_generator(), @@ -253,23 +262,34 @@ async def chat( } ) - # 非流式返回 + # 多 Agent 非流式返回 result = await app_chat_service.workflow_chat( message=payload.message, 
conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=end_user_id, # 转换为字符串 + user_id=new_end_user.id, # 转换为字符串 variables=payload.variables, config=config, - web_search=web_search, - memory=memory, + web_search=payload.web_search, + memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + app_id=app.app_id, + workspace_id=workspace_id + ) + logger.debug( + "工作流试运行返回结果", + extra={ + "result_type": str(type(result)), + "has_response": "response" in result if isinstance(result, dict) else False + } + ) + return success( + data=result, + msg="工作流任务执行成功" ) - - return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) else: from app.core.exceptions import BusinessException from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) - pass + diff --git a/api/app/controllers/test_controller.py b/api/app/controllers/test_controller.py index 98cbe26e..5746405a 100644 --- a/api/app/controllers/test_controller.py +++ b/api/app/controllers/test_controller.py @@ -1,23 +1,22 @@ -from fastapi import APIRouter, Depends, status, Query, HTTPException -from langchain_core.messages import HumanMessage, SystemMessage +from fastapi import APIRouter, Depends, status, HTTPException, Body, Path +from fastapi.responses import StreamingResponse from langchain_core.prompts import ChatPromptTemplate from sqlalchemy.orm import Session -from typing import List, Optional import uuid - from app.core.models import RedBearLLM, RedBearRerank from app.core.models.base import RedBearModelConfig from app.core.models.embedding import RedBearEmbeddings from app.db import get_db -from app.dependencies import get_current_user -from app.models.models_model import ModelApiKey, ModelProvider, ModelType -from app.models.user_model import User -from app.schemas import model_schema +from app.models.models_model import ModelApiKey from 
app.core.response_utils import success -from app.schemas.response_schema import ApiResponse, PageData -from app.services.model_service import ModelConfigService, ModelApiKeyService +from app.schemas.response_schema import ApiResponse +from app.schemas.app_schema import AppChatRequest +from app.services.model_service import ModelConfigService +from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache +from app.services.conversation_service import ConversationService from app.core.logging_config import get_api_logger +from app.dependencies import get_current_user # 获取API专用日志器 api_logger = get_api_logger() @@ -28,6 +27,8 @@ router = APIRouter( ) +# ==================== 原有测试接口 ==================== + @router.get("/llm/{model_id}", response_model=ApiResponse) def test_llm( model_id: uuid.UUID, @@ -50,7 +51,6 @@ def test_llm( template = """Question: {question} Answer: Let's think step by step.""" - # ChatPromptTemplate prompt = ChatPromptTemplate.from_template(template) chain = prompt | llm answer = chain.invoke({"question": "What is LangChain?"}) @@ -80,13 +80,13 @@ def test_embedding( base_url=apiConfig.api_base )) - data = [ - "最近哪家咖啡店评价最好?", - "附近有没有推荐的咖啡厅?", - "明天天气预报说会下雨。", - "北京是中国的首都。", - "我想找一个适合学习的地方。" - ] + data = [ + "最近哪家咖啡店评价最好?", + "附近有没有推荐的咖啡厅?", + "明天天气预报说会下雨。", + "北京是中国的首都。", + "我想找一个适合学习的地方。" + ] embeddings = model.embed_documents(data) print(embeddings) query = "我想找一个适合学习的地方。" @@ -114,13 +114,123 @@ def test_rerank( base_url=apiConfig.api_base )) query = "最近哪家咖啡店评价最好?" 
- data = [ - "最近哪家咖啡店评价最好?", - "附近有没有推荐的咖啡厅?", - "明天天气预报说会下雨。", - "北京是中国的首都。", - "我想找一个适合学习的地方。" - ] + data = [ + "最近哪家咖啡店评价最好?", + "附近有没有推荐的咖啡厅?", + "明天天气预报说会下雨。", + "北京是中国的首都。", + "我想找一个适合学习的地方。" + ] scores = model.rerank(query=query, documents=data, top_n=3) print(scores) return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores}) + + +# ==================== Handoffs 测试接口 ==================== + +@router.post("/handoffs/{app_id}") +async def test_handoffs( + app_id: uuid.UUID = Path(..., description="应用 ID"), + request: AppChatRequest = Body(...), + current_user=Depends(get_current_user), + db: Session = Depends(get_db) +): + """测试 Agent Handoffs 功能 + + 演示 LangGraph 实现的多 Agent 协作和动态切换 + + - 从数据库 multi_agent_config 获取 Agent 配置 + - 根据用户问题自动切换到合适的 Agent + - 使用 conversation_id 保持会话状态 + - 通过 stream 参数控制是否流式输出 + + 事件类型(流式): + - start: 开始执行 + - agent: 当前 Agent 信息 + - message: 流式消息内容 + - handoff: Agent 切换事件 + - end: 执行结束 + - error: 错误信息 + """ + try: + workspace_id = current_user.current_workspace_id + + # 获取或创建会话 + conversation_service = ConversationService(db) + + if request.conversation_id: + # 验证会话存在 + conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id)) + if not conversation: + raise HTTPException(status_code=404, detail="会话不存在") + conversation_id = str(conversation.id) + else: + # 创建新会话 + conversation = conversation_service.create_or_get_conversation( + app_id=app_id, + workspace_id=workspace_id, + user_id=request.user_id, + is_draft=True + ) + conversation_id = str(conversation.id) + + # 根据 stream 参数决定返回方式 + if request.stream: + # 流式返回 + service = get_handoffs_service_for_app(app_id, db, streaming=True) + return StreamingResponse( + service.chat_stream( + message=request.message, + conversation_id=conversation_id + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + else: + # 非流式返回 + service = 
get_handoffs_service_for_app(app_id, db, streaming=False) + result = await service.chat( + message=request.message, + conversation_id=conversation_id + ) + return success(data=result, msg="Handoffs 测试成功") + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Handoffs 测试失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse) +def get_handoff_agents( + app_id: uuid.UUID = Path(..., description="应用 ID"), + db: Session = Depends(get_db), + current_user=Depends(get_current_user) +): + """获取应用的 Handoff Agent 列表""" + try: + service = get_handoffs_service_for_app(app_id, db, streaming=False) + agents = service.get_agents() + return success(data={"agents": agents}, msg="获取 Agent 列表成功") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + api_logger.error(f"获取 Agent 列表失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/handoffs/{app_id}/reset") +def reset_handoff_service( + app_id: uuid.UUID = Path(..., description="应用 ID"), + current_user=Depends(get_current_user) +): + """重置指定应用的 Handoff 服务缓存""" + reset_handoffs_service_cache(app_id) + return success(msg="Handoff 服务已重置") diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 479686ef..a3624ea4 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -215,8 +215,8 @@ async def sync_mcp_tools( """同步MCP工具列表""" try: result = await service.sync_mcp_tools(tool_id, current_user.tenant_id) - if result["success"] is False: - raise HTTPException(status_code=404, detail=result["message"]) + if not result.get("success", False): + raise HTTPException(status_code=400, detail=result.get("message", "同步失败")) return success(data=result, msg="MCP工具列表同步完成") except Exception as 
e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index 15c50601..5fd9b841 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -11,14 +11,22 @@ from app.db import get_db from app.core.logging_config import get_api_logger from app.core.response_utils import success, fail from app.core.error_codes import BizCode +from app.core.api_key_utils import timestamp_to_datetime from app.services.user_memory_service import ( UserMemoryService, - analytics_node_statistics, analytics_memory_types, analytics_graph_data, ) +from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest +from app.schemas.user_memory_schema import ( + EpisodicMemoryOverviewRequest, + EpisodicMemoryDetailsRequest, + ExplicitMemoryOverviewRequest, + ExplicitMemoryDetailsRequest, +) + from app.schemas.end_user_schema import ( EndUserProfileResponse, EndUserProfileUpdate, @@ -41,24 +49,27 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( - end_user_id: str, # 使用 end_user_id + end_user_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: - """获取缓存的记忆洞察报告""" - api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}") +) -> dict: + """ + 获取缓存的记忆洞察报告 + + 此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。 + 如需生成新的洞察报告,请使用专门的生成接口。 + """ + api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 result = await user_memory_service.get_cached_memory_insight(db, end_user_id) - + if result["is_cached"]: - # 缓存存在,返回缓存数据 api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") 
return success(data=result, msg="查询成功") else: - # 缓存不存在,返回提示消息 api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}") - return success(data=result, msg="查询成功") + return success(data=result, msg="数据尚未生成") except Exception as e: api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e)) @@ -66,24 +77,27 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( - end_user_id: str, # 使用 end_user_id + end_user_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: - """获取缓存的用户摘要""" - api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}") +) -> dict: + """ + 获取缓存的用户摘要 + + 此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。 + 如需生成新的用户摘要,请使用专门的生成接口。 + """ + api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 result = await user_memory_service.get_cached_user_summary(db, end_user_id) - + if result["is_cached"]: - # 缓存存在,返回缓存数据 api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") return success(data=result, msg="查询成功") else: - # 缓存不存在,返回提示消息 api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}") - return success(data=result, msg="查询成功") + return success(data=result, msg="数据尚未生成") except Exception as e: api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) @@ -97,35 +111,35 @@ async def generate_cache_api( ) -> dict: """ 手动触发缓存生成 - + - 如果提供 end_user_id,只为该用户生成 - 如果不提供,为当前工作空间的所有用户生成 """ workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + group_id = request.end_user_id - + api_logger.info( f"缓存生成请求: user={current_user.username}, 
workspace={workspace_id}, " f"end_user_id={group_id if group_id else '全部用户'}" ) - + try: if group_id: # 为单个用户生成 api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}") - + # 生成记忆洞察 insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id) - + # 生成用户摘要 summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id) - + # 构建响应 result = { "end_user_id": group_id, @@ -133,7 +147,7 @@ async def generate_cache_api( "summary_success": summary_result["success"], "errors": [] } - + # 收集错误信息 if not insight_result["success"]: result["errors"].append({ @@ -145,29 +159,29 @@ async def generate_cache_api( "type": "summary", "error": summary_result.get("error") }) - + # 记录结果 if result["insight_success"] and result["summary_success"]: api_logger.info(f"成功为用户 {group_id} 生成缓存") else: api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}") - + return success(data=result, msg="生成完成") - + else: # 为整个工作空间生成 api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存") - + result = await user_memory_service.generate_cache_for_workspace(db, workspace_id) - + # 记录统计信息 api_logger.info( f"工作空间 {workspace_id} 批量生成完成: " f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}" ) - + return success(data=result, msg="批量生成完成") - + except Exception as e: api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e)) @@ -180,18 +194,18 @@ async def get_node_statistics_api( db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") - + try: # 调用新的记忆类型统计函数 result = await analytics_memory_types(db, 
end_user_id) - + # 计算总数用于日志 total_count = sum(item["count"] for item in result) api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") @@ -211,31 +225,31 @@ async def get_graph_data_api( db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + # 参数验证 if limit > 1000: limit = 1000 api_logger.warning("limit 参数超过最大值,已调整为 1000") - + if depth > 3: depth = 3 api_logger.warning("depth 参数超过最大值,已调整为 3") - + # 解析 node_types 参数 node_types_list = None if node_types: node_types_list = [t.strip() for t in node_types.split(",") if t.strip()] - + api_logger.info( f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, " f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}" ) - + try: result = await analytics_graph_data( db=db, @@ -245,19 +259,19 @@ async def get_graph_data_api( depth=depth, center_node_id=center_node_id ) - + # 检查是否有错误消息 if "message" in result and result["statistics"]["total_nodes"] == 0: api_logger.warning(f"图数据查询返回空结果: {result.get('message')}") return success(data=result, msg=result.get("message", "查询成功")) - + api_logger.info( f"成功获取图数据: end_user_id={end_user_id}, " f"nodes={result['statistics']['total_nodes']}, " f"edges={result['statistics']['total_edges']}" ) return success(data=result, msg="查询成功") - + except Exception as e: api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e)) @@ -270,25 +284,25 @@ async def get_end_user_profile( db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, 
"请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info( f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - + try: # 查询终端用户 end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() - + if not end_user: api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") - + # 构建响应数据 profile_data = EndUserProfileResponse( id=end_user.id, @@ -300,10 +314,10 @@ async def get_end_user_profile( hire_date=end_user.hire_date, updatetime_profile=end_user.updatetime_profile ) - + api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}") return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功") - + except Exception as e: api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e)) @@ -317,56 +331,56 @@ async def update_end_user_profile( ) -> dict: """ 更新终端用户的基本信息 - + 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。 所有字段都是可选的,只更新提供的字段。 - + """ workspace_id = current_user.current_workspace_id end_user_id = profile_update.end_user_id - + # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - + api_logger.info( f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, " f"workspace={workspace_id}" ) - + try: # 查询终端用户 end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() - + if not end_user: api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") - + # 更新字段(只更新提供的字段,排除 end_user_id) # 允许 None 值来重置字段(如 hire_date) update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) - + # 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime if 'hire_date' in update_data: hire_date_timestamp = 
update_data['hire_date'] if hire_date_timestamp is not None: - update_data['hire_date'] = UserMemoryService.timestamp_to_datetime(hire_date_timestamp) + update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) # 如果是 None,保持 None(允许清空) - + for field, value in update_data.items(): setattr(end_user, field, value) - + # 更新 updated_at 时间戳 end_user.updated_at = datetime.datetime.now() - + # 更新 updatetime_profile 为当前时间 end_user.updatetime_profile = datetime.datetime.now() - + # 提交更改 db.commit() db.refresh(end_user) - + # 构建响应数据 profile_data = EndUserProfileResponse( id=end_user.id, @@ -378,11 +392,243 @@ async def update_end_user_profile( hire_date=end_user.hire_date, updatetime_profile=end_user.updatetime_profile ) - + api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}") return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功") - + except Exception as e: db.rollback() api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e)) + +@router.get("/memory_space/timeline_memories", response_model=ApiResponse) +async def memory_space_timeline_of_shared_memories(id: str, label: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ): + MemoryEntity = MemoryEntityService(id, label) + timeline_memories_result = await MemoryEntity.get_timeline_memories_server() + return success(data=timeline_memories_result, msg="共同记忆时间线") +@router.get("/memory_space/relationship_evolution", response_model=ApiResponse) +async def memory_space_relationship_evolution(id: str, label: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ): + try: + api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") + + # 获取情绪数据 + emotion = MemoryEmotion(id, label) + emotion_result = await emotion.get_emotion() + + # 获取交互数据 + interaction = 
MemoryInteraction(id, label) + interaction_result = await interaction.get_interaction_frequency() + + # 关闭连接 + await emotion.close() + await interaction.close() + + result = { + "emotion": emotion_result, + "interaction": interaction_result + } + + api_logger.info(f"关系演变查询成功: id={id}, table={label}") + return success(data=result, msg="关系演变") + + except Exception as e: + api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "关系演变查询失败", str(e)) + + +@router.post("/classifications/episodic-memory", response_model=ApiResponse) +async def get_episodic_memory_overview_api( + request: EpisodicMemoryOverviewRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 获取情景记忆总览 + + 返回指定用户的所有情景记忆列表,包括标题和创建时间。 + 支持通过时间范围、情景类型和标题关键词进行筛选。 + + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆总览但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + # 验证参数 + valid_time_ranges = ["all", "today", "this_week", "this_month"] + valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"] + + if request.time_range not in valid_time_ranges: + return fail(BizCode.INVALID_PARAMETER, f"无效的时间范围参数,可选值:{', '.join(valid_time_ranges)}") + + if request.episodic_type not in valid_episodic_types: + return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}") + + # 处理 title_keyword(去除首尾空格) + title_keyword = request.title_keyword.strip() if request.title_keyword else None + + api_logger.info( + f"情景记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}, time_range={request.time_range}, episodic_type={request.episodic_type}, " + f"title_keyword={title_keyword}" + ) + + try: + # 调用Service层方法 + result = 
await user_memory_service.get_episodic_memory_overview( + db, request.end_user_id, request.time_range, request.episodic_type, title_keyword + ) + + api_logger.info( + f"成功获取情景记忆总览: end_user_id={request.end_user_id}, " + f"total={result['total']}" + ) + return success(data=result, msg="查询成功") + + except Exception as e: + api_logger.error(f"情景记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "情景记忆总览查询失败", str(e)) + + +@router.post("/classifications/episodic-memory-details", response_model=ApiResponse) +async def get_episodic_memory_details_api( + request: EpisodicMemoryDetailsRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 获取情景记忆详情 + + 返回指定情景记忆的详细信息,包括涉及对象、情景类型、内容记录和情绪。 + + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆详情但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"情景记忆详情查询请求: end_user_id={request.end_user_id}, summary_id={request.summary_id}, " + f"user={current_user.username}, workspace={workspace_id}" + ) + + try: + # 调用Service层方法 + result = await user_memory_service.get_episodic_memory_details( + db=db, + end_user_id=request.end_user_id, + summary_id=request.summary_id + ) + + api_logger.info( + f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}" + ) + return success(data=result, msg="查询成功") + + except ValueError as e: + # 处理情景记忆不存在的情况 + api_logger.warning(f"情景记忆不存在: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}") + return fail(BizCode.INVALID_PARAMETER, "情景记忆不存在", str(e)) + except Exception as e: + api_logger.error(f"情景记忆详情查询失败: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "情景记忆详情查询失败", str(e)) + + 
+@router.post("/classifications/explicit-memory", response_model=ApiResponse) +async def get_explicit_memory_overview_api( + request: ExplicitMemoryOverviewRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 获取显性记忆总览 + + 返回指定用户的所有显性记忆列表,包括标题、完整内容、创建时间和情绪信息。 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆总览但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"显性记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}" + ) + + try: + # 调用Service层方法 + result = await user_memory_service.get_explicit_memory_overview( + db, request.end_user_id + ) + + api_logger.info( + f"成功获取显性记忆总览: end_user_id={request.end_user_id}, " + f"total={result['total']}" + ) + return success(data=result, msg="查询成功") + + except Exception as e: + api_logger.error(f"显性记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e)) + + +@router.post("/classifications/explicit-memory-details", response_model=ApiResponse) +async def get_explicit_memory_details_api( + request: ExplicitMemoryDetailsRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 获取显性记忆详情 + + 根据 memory_id 返回情景记忆或语义记忆的详细信息。 + - 情景记忆:包括标题、内容、情绪、创建时间 + - 语义记忆:包括名称、核心定义、详细笔记、创建时间 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆详情但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"显性记忆详情查询请求: end_user_id={request.end_user_id}, memory_id={request.memory_id}, " + f"user={current_user.username}, workspace={workspace_id}" + ) + + try: + # 调用Service层方法 + 
result = await user_memory_service.get_explicit_memory_details( + db=db, + end_user_id=request.end_user_id, + memory_id=request.memory_id + ) + + api_logger.info( + f"成功获取显性记忆详情: end_user_id={request.end_user_id}, memory_id={request.memory_id}, " + f"memory_type={result.get('memory_type')}" + ) + return success(data=result, msg="查询成功") + + except ValueError as e: + # 处理记忆不存在的情况 + api_logger.warning(f"显性记忆不存在: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}") + return fail(BizCode.INVALID_PARAMETER, "显性记忆不存在", str(e)) + except Exception as e: + api_logger.error(f"显性记忆详情查询失败: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "显性记忆详情查询失败", str(e)) + + diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 380b660c..91445b12 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,10 +11,16 @@ import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence + +from app.db import get_db from app.core.logging_config import get_business_logger from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.services.memory_agent_service import ( + get_end_user_connected_config, +) from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task @@ -159,11 +165,13 @@ class LangChainAgent: history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) # logger.info(f'Redis_Agent:{end_user_end};{history}') messagss_list=[] + retrieved_content=[] for messages in history: query = messages.get("Query") aimessages = messages.get("Answer") 
messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - return messagss_list + retrieved_content.append({query: aimessages}) + return messagss_list,retrieved_content async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): @@ -203,7 +211,6 @@ class LangChainAgent: # If config_id is None, try to get from end_user's connected config if actual_config_id is None and end_user_id: try: - from app.db import get_db from app.services.memory_agent_service import ( get_end_user_connected_config, ) @@ -221,11 +228,26 @@ class LangChainAgent: logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') - history_term_memory=await self.term_memory_redis_read(end_user_id) + history_term_memory_result = await self.term_memory_redis_read(end_user_id) + history_term_memory = history_term_memory_result[0] + db_for_memory = next(get_db()) if memory_flag: if len(history_term_memory)>=4 and storage_type != "rag": - history_term_memory=';'.join(history_term_memory) - logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') + history_term_memory = ';'.join(history_term_memory) + retrieved_content = history_term_memory_result[1] + print(retrieved_content) + # 为长期记忆操作获取新的数据库连接 + try: + repo = LongTermMemoryRepository(db_for_memory) + repo.upsert(end_user_id, retrieved_content) + logger.info( + f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') + except Exception as e: + logger.error(f"Failed to write to LongTermMemory: {e}") + raise + finally: + db_for_memory.close() + await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id) await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id) try: @@ -314,10 +336,6 @@ 
class LangChainAgent: # If config_id is None, try to get from end_user's connected config if actual_config_id is None and end_user_id: try: - from app.db import get_db - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) db = next(get_db()) try: connected_config = get_end_user_connected_config(end_user_id, db) @@ -329,14 +347,24 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") - history_term_memory = await self.term_memory_redis_read(end_user_id) + history_term_memory_result = await self.term_memory_redis_read(end_user_id) + history_term_memory = history_term_memory_result[0] if memory_flag: if len(history_term_memory) >= 4 and storage_type != "rag": history_term_memory = ';'.join(history_term_memory) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, - history_term_memory, actual_config_id) + retrieved_content = history_term_memory_result[1] + db_for_memory = next(get_db()) + try: + repo = LongTermMemoryRepository(db_for_memory) + repo.upsert(end_user_id, retrieved_content) + logger.info( + f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') + await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, + history_term_memory, actual_config_id) + except Exception as e: + logger.error(f"Failed to write to long term memory: {e}") + finally: + db_for_memory.close() await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id) try: diff --git a/api/app/core/config.py b/api/app/core/config.py index 7494b89d..d9b9cea8 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -164,6 +164,10 @@ class Settings: TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) TOOL_MAX_CONCURRENCY: 
int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10")) ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true" + + # official environment system version + SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") + SYSTEM_INTRODUCTION: str = os.getenv("SYSTEM_INTRODUCTION", "") def get_memory_output_path(self, filename: str = "") -> str: """ diff --git a/api/app/core/error_codes.py b/api/app/core/error_codes.py index d0aa9cc1..23023ca4 100644 --- a/api/app/core/error_codes.py +++ b/api/app/core/error_codes.py @@ -82,6 +82,13 @@ class BizCode(IntEnum): MEMORY_WRITE_FAILED = 9501 MEMORY_READ_FAILED = 9502 MEMORY_CONFIG_NOT_FOUND = 9503 + + # Implicit Memory API(96xx) + INVALID_USER_ID = 9601 + INSUFFICIENT_DATA = 9602 + INVALID_FILTER_PARAMS = 9603 + ANALYSIS_FAILED = 9604 + PROFILE_STORAGE_ERROR = 9605 # 系统(100xx) INTERNAL_ERROR = 10001 @@ -159,6 +166,13 @@ HTTP_MAPPING = { BizCode.MEMORY_READ_FAILED: 500, BizCode.MEMORY_CONFIG_NOT_FOUND: 400, + # Implicit Memory API 错误码映射 + BizCode.INVALID_USER_ID: 400, + BizCode.INSUFFICIENT_DATA: 400, + BizCode.INVALID_FILTER_PARAMS: 400, + BizCode.ANALYSIS_FAILED: 500, + BizCode.PROFILE_STORAGE_ERROR: 500, + BizCode.INTERNAL_ERROR: 500, BizCode.DB_ERROR: 500, BizCode.SERVICE_UNAVAILABLE: 503, diff --git a/api/app/core/memory/analytics/__init__.py b/api/app/core/memory/analytics/__init__.py index 06aeaed3..6811ff8f 100644 --- a/api/app/core/memory/analytics/__init__.py +++ b/api/app/core/memory/analytics/__init__.py @@ -5,19 +5,16 @@ This module provides analytics and insights for the memory system. Available functions: - get_hot_memory_tags: Get hot memory tags by frequency -- MemoryInsight: Generate memory insight reports - get_recent_activity_stats: Get recent activity statistics -- generate_user_summary: Generate user summary + +Note: MemoryInsight and generate_user_summary have been moved to +app.services.user_memory_service for better architecture. 
""" from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.analytics.memory_insight import MemoryInsight from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats -from app.core.memory.analytics.user_summary import generate_user_summary __all__ = [ "get_hot_memory_tags", - "MemoryInsight", "get_recent_activity_stats", - "generate_user_summary", ] diff --git a/api/app/core/memory/analytics/implicit_memory/__init__.py b/api/app/core/memory/analytics/implicit_memory/__init__.py new file mode 100644 index 00000000..de10bc85 --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/__init__.py @@ -0,0 +1,6 @@ +"""Implicit Memory Module + +This module provides behavior analysis capabilities that build comprehensive user profiles +by analyzing memory summary nodes from Neo4j. It creates detailed user portraits across +multiple dimensions, tracks interest distributions, and identifies behavioral habits. +""" \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/analyzers/__init__.py b/api/app/core/memory/analytics/implicit_memory/analyzers/__init__.py new file mode 100644 index 00000000..305b281c --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/analyzers/__init__.py @@ -0,0 +1 @@ +"""Analyzers package for implicit memory analysis components.""" \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/analyzers/dimension_analyzer.py b/api/app/core/memory/analytics/implicit_memory/analyzers/dimension_analyzer.py new file mode 100644 index 00000000..521ac383 --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/analyzers/dimension_analyzer.py @@ -0,0 +1,271 @@ +"""Dimension Analyzer for Implicit Memory System + +This module implements LLM-based personality dimension analysis from user memory summaries. 
+It analyzes four key dimensions: creativity, aesthetic, technology, and literature, +providing percentage scores with evidence and reasoning. +""" + +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient +from app.core.memory.llm_tools.llm_client import LLMClientException +from app.schemas.implicit_memory_schema import ( + DimensionPortrait, + DimensionScore, + UserMemorySummary, +) +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class DimensionData(BaseModel): + """Individual dimension analysis data.""" + percentage: float = Field(ge=0.0, le=100.0) + evidence: List[str] = Field(default_factory=list) + reasoning: str = "" + confidence_level: int = 50 # Default to medium confidence + + +class DimensionAnalysisResponse(BaseModel): + """Response model for dimension analysis.""" + dimensions: Dict[str, DimensionData] = Field(default_factory=dict) + + +class DimensionAnalyzer: + """Analyzes user memory summaries to extract personality dimensions.""" + + # Define the four dimensions we analyze + DIMENSIONS = ["creativity", "aesthetic", "technology", "literature"] + + def __init__(self, db: Session, llm_model_id: Optional[str] = None): + """Initialize the dimension analyzer. + + Args: + db: Database session + llm_model_id: Optional LLM model ID to use for analysis + """ + self.db = db + self.llm_model_id = llm_model_id + self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id) + + async def analyze_dimensions( + self, + user_id: str, + user_summaries: List[UserMemorySummary], + existing_portrait: Optional[DimensionPortrait] = None + ) -> DimensionPortrait: + """Analyze user summaries to extract personality dimensions. 
+ + Args: + user_id: Target user ID + user_summaries: List of user-specific memory summaries + existing_portrait: Optional existing portrait for incremental updates + + Returns: + Dimension portrait with four personality dimensions + + Raises: + LLMClientException: If LLM analysis fails + """ + if not user_summaries: + logger.warning(f"No summaries provided for user {user_id}") + return self._create_empty_portrait(user_id) + + try: + logger.info(f"Analyzing dimensions for user {user_id} with {len(user_summaries)} summaries") + + # Use the LLM client wrapper for analysis + response = await self._llm_client.analyze_dimensions( + user_summaries=user_summaries, + user_id=user_id, + model_id=self.llm_model_id + ) + + # Create dimension scores + dimension_scores = {} + current_time = datetime.now() + + for dimension_name in self.DIMENSIONS: + # Handle response as dictionary + dimensions_data = response.get("dimensions", {}) + dimension_data = dimensions_data.get(dimension_name) + + if dimension_data: + # Validate and create dimension score + score = self._create_dimension_score( + dimension_name=dimension_name, + dimension_data=dimension_data + ) + dimension_scores[dimension_name] = score + else: + # Create default score if missing + logger.warning(f"Missing dimension data for {dimension_name}, using default") + dimension_scores[dimension_name] = self._create_default_dimension_score(dimension_name) + + # Create dimension portrait + portrait = DimensionPortrait( + user_id=user_id, + creativity=dimension_scores["creativity"], + aesthetic=dimension_scores["aesthetic"], + technology=dimension_scores["technology"], + literature=dimension_scores["literature"], + analysis_timestamp=current_time, + total_summaries_analyzed=len(user_summaries), + historical_trends=self._calculate_historical_trends(existing_portrait) if existing_portrait else None + ) + + logger.info(f"Created dimension portrait for user {user_id}") + return portrait + + except LLMClientException: + raise + except 
Exception as e: + logger.error(f"Dimension analysis failed for user {user_id}: {e}") + raise LLMClientException(f"Dimension analysis failed: {e}") from e + + def _create_dimension_score( + self, + dimension_name: str, + dimension_data: dict + ) -> DimensionScore: + """Create a dimension score from analysis data. + + Args: + dimension_name: Name of the dimension + dimension_data: Analysis data dictionary for the dimension + + Returns: + DimensionScore object + """ + # Validate percentage - handle dict access + percentage = dimension_data.get("percentage", 0.0) + percentage = max(0.0, min(100.0, float(percentage))) + + # Validate confidence level + confidence_level = self._validate_confidence_level(dimension_data.get("confidence_level", 50)) + + # Ensure evidence is not empty + evidence = dimension_data.get("evidence", []) + if not evidence: + evidence = ["No specific evidence found"] + + # Ensure reasoning is not empty + reasoning = dimension_data.get("reasoning", "") + if not reasoning: + reasoning = f"Analysis for {dimension_name} dimension" + + return DimensionScore( + dimension_name=dimension_name, + percentage=percentage, + evidence=evidence, + reasoning=reasoning, + confidence_level=confidence_level + ) + + def _create_default_dimension_score(self, dimension_name: str) -> DimensionScore: + """Create a default dimension score when analysis fails. + + Args: + dimension_name: Name of the dimension + + Returns: + Default DimensionScore object + """ + return DimensionScore( + dimension_name=dimension_name, + percentage=0.0, + evidence=["Insufficient data for analysis"], + reasoning=f"No clear evidence found for {dimension_name} dimension", + confidence_level=20 # Low confidence as numerical value + ) + + def _validate_confidence_level(self, confidence_level) -> int: + """Return confidence level as integer, handling both string and numeric inputs. 
+ + Args: + confidence_level: Confidence level (string or numeric) + + Returns: + Confidence level as integer (0-100) + """ + # If it's already a number, return it as int + if isinstance(confidence_level, (int, float)): + return int(confidence_level) + + # If it's a string, convert common values to numbers + if isinstance(confidence_level, str): + confidence_str = confidence_level.lower().strip() + if confidence_str in ["high", "높음"]: + return 85 + elif confidence_str in ["medium", "중간"]: + return 50 + elif confidence_str in ["low", "낮음"]: + return 20 + else: + # Try to parse as number + try: + return int(float(confidence_str)) + except ValueError: + logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium") + return 50 + + # Default fallback + return 50 + + def _create_empty_portrait(self, user_id: str) -> DimensionPortrait: + """Create an empty dimension portrait when no data is available. + + Args: + user_id: Target user ID + + Returns: + Empty DimensionPortrait + """ + current_time = datetime.now() + + return DimensionPortrait( + user_id=user_id, + creativity=self._create_default_dimension_score("creativity"), + aesthetic=self._create_default_dimension_score("aesthetic"), + technology=self._create_default_dimension_score("technology"), + literature=self._create_default_dimension_score("literature"), + analysis_timestamp=current_time, + total_summaries_analyzed=0, + historical_trends=None + ) + + def _calculate_historical_trends( + self, + existing_portrait: DimensionPortrait + ) -> List[Dict[str, Any]]: + """Calculate historical trends from existing portrait. 
+ + Args: + existing_portrait: Previous dimension portrait + + Returns: + List of historical trend data + """ + if not existing_portrait: + return [] + + # Create trend entry from existing portrait + trend_entry = { + "timestamp": existing_portrait.analysis_timestamp.isoformat(), + "creativity": existing_portrait.creativity.percentage, + "aesthetic": existing_portrait.aesthetic.percentage, + "technology": existing_portrait.technology.percentage, + "literature": existing_portrait.literature.percentage, + "total_summaries": existing_portrait.total_summaries_analyzed + } + + # Combine with existing trends + existing_trends = existing_portrait.historical_trends or [] + + # Keep only recent trends (last 10 entries) + all_trends = existing_trends + [trend_entry] + return all_trends[-10:] \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/analyzers/habit_analyzer.py b/api/app/core/memory/analytics/implicit_memory/analyzers/habit_analyzer.py new file mode 100644 index 00000000..dbc0817d --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/analyzers/habit_analyzer.py @@ -0,0 +1,459 @@ +"""Habit Analyzer for Implicit Memory System + +This module implements LLM-based behavioral habit analysis from user memory summaries. +It identifies recurring behavioral patterns, temporal patterns, and consolidates +similar habits with confidence scoring. 
+""" + +import logging +from datetime import datetime +from typing import List, Optional + +from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient +from app.core.memory.llm_tools.llm_client import LLMClientException +from app.schemas.implicit_memory_schema import ( + BehaviorHabit, + FrequencyPattern, + UserMemorySummary, +) +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class HabitData(BaseModel): + """Individual habit analysis data.""" + habit_description: str + frequency_pattern: str + time_context: str + confidence_level: int = 50 # Default to medium confidence + supporting_summaries: List[str] = Field(default_factory=list) + specific_examples: List[str] = Field(default_factory=list) + is_current: bool = True + + +class HabitAnalysisResponse(BaseModel): + """Response model for habit analysis.""" + habits: List[HabitData] = Field(default_factory=list) + + +class HabitAnalyzer: + """Analyzes user memory summaries to extract behavioral habits.""" + + def __init__(self, db: Session, llm_model_id: Optional[str] = None): + """Initialize the habit analyzer. + + Args: + db: Database session + llm_model_id: Optional LLM model ID to use for analysis + """ + self.db = db + self.llm_model_id = llm_model_id + self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id) + + async def analyze_habits( + self, + user_id: str, + user_summaries: List[UserMemorySummary], + existing_habits: Optional[List[BehaviorHabit]] = None + ) -> List[BehaviorHabit]: + """Analyze user summaries to extract behavioral habits. 
+ + Args: + user_id: Target user ID + user_summaries: List of user-specific memory summaries + existing_habits: Optional existing habits for consolidation + + Returns: + List of extracted behavioral habits + + Raises: + LLMClientException: If LLM analysis fails + """ + if not user_summaries: + logger.warning(f"No summaries provided for user {user_id}") + return existing_habits or [] + + try: + logger.info(f"Analyzing habits for user {user_id} with {len(user_summaries)} summaries") + + # Use the LLM client wrapper for analysis + response = await self._llm_client.analyze_habits( + user_summaries=user_summaries, + user_id=user_id, + model_id=self.llm_model_id + ) + + # Convert to BehaviorHabit objects + behavior_habits = [] + + for habit_data in response.get("habits", []): + try: + # Handle habit_data as dictionary + supporting_summaries = habit_data.get("supporting_summaries", []) + specific_examples = habit_data.get("specific_examples", []) + + # Determine observation dates from summaries + first_observed, last_observed = self._determine_observation_dates( + user_summaries, supporting_summaries + ) + + behavior_habit = BehaviorHabit( + habit_description=habit_data.get("habit_description", ""), + frequency_pattern=self._validate_frequency_pattern(habit_data.get("frequency_pattern", "occasional")), + time_context=habit_data.get("time_context", ""), + confidence_level=self._validate_confidence_level(habit_data.get("confidence_level", 50)), + specific_examples=specific_examples, + first_observed=first_observed, + last_observed=last_observed, + is_current=habit_data.get("is_current", True) + ) + + # Validate habit + if self._is_valid_habit(behavior_habit): + behavior_habits.append(behavior_habit) + else: + logger.warning(f"Invalid habit skipped: {behavior_habit.habit_description}") + + except Exception as e: + logger.error(f"Error creating behavior habit: {e}") + continue + + # Consolidate with existing habits if provided + if existing_habits: + behavior_habits = 
self._consolidate_habits( + new_habits=behavior_habits, + existing_habits=existing_habits + ) + + # Sort habits by confidence and recency + behavior_habits = self._sort_habits_by_priority(behavior_habits) + + logger.info(f"Extracted {len(behavior_habits)} habits for user {user_id}") + return behavior_habits + + except LLMClientException: + raise + except Exception as e: + logger.error(f"Habit analysis failed for user {user_id}: {e}") + raise LLMClientException(f"Habit analysis failed: {e}") from e + + def _validate_frequency_pattern(self, frequency_str: str) -> FrequencyPattern: + """Validate and convert frequency pattern string. + + Args: + frequency_str: Frequency pattern as string + + Returns: + FrequencyPattern enum value + """ + frequency_str = frequency_str.lower().strip() + + frequency_mapping = { + "daily": FrequencyPattern.DAILY, + "weekly": FrequencyPattern.WEEKLY, + "monthly": FrequencyPattern.MONTHLY, + "seasonal": FrequencyPattern.SEASONAL, + "occasional": FrequencyPattern.OCCASIONAL, + "event_triggered": FrequencyPattern.EVENT_TRIGGERED, + "event-triggered": FrequencyPattern.EVENT_TRIGGERED, + } + + return frequency_mapping.get(frequency_str, FrequencyPattern.OCCASIONAL) + + def _validate_confidence_level(self, confidence_level) -> int: + """Return confidence level as integer, handling both string and numeric inputs. 
+ + Args: + confidence_level: Confidence level (string or numeric) + + Returns: + Confidence level as integer (0-100) + """ + # If it's already a number, return it as int + if isinstance(confidence_level, (int, float)): + return int(confidence_level) + + # If it's a string, convert common values to numbers + if isinstance(confidence_level, str): + confidence_str = confidence_level.lower().strip() + if confidence_str in ["high", "높음"]: + return 85 + elif confidence_str in ["medium", "중간"]: + return 50 + elif confidence_str in ["low", "낮음"]: + return 20 + else: + # Try to parse as number + try: + return int(float(confidence_str)) + except ValueError: + logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium") + return 50 + + # Default fallback + return 50 + + def _determine_observation_dates( + self, + user_summaries: List[UserMemorySummary], + supporting_summary_ids: List[str] + ) -> tuple[datetime, datetime]: + """Determine first and last observation dates for a habit. 
+ + Args: + user_summaries: List of user summaries + supporting_summary_ids: IDs of summaries supporting the habit + + Returns: + Tuple of (first_observed, last_observed) dates + """ + from datetime import timezone + + # Find summaries that support this habit + supporting_summaries = [ + summary for summary in user_summaries + if summary.summary_id in supporting_summary_ids + ] + + if not supporting_summaries: + # Use all summaries if no specific supporting summaries found + supporting_summaries = user_summaries + + if not supporting_summaries: + current_time = datetime.now(timezone.utc).replace(tzinfo=None) + return current_time, current_time + + # Get date range from supporting summaries - normalize to naive datetimes + timestamps = [] + for summary in supporting_summaries: + ts = summary.timestamp + # Convert to naive datetime if it's timezone-aware + if ts.tzinfo is not None: + ts = ts.replace(tzinfo=None) + timestamps.append(ts) + + first_observed = min(timestamps) + last_observed = max(timestamps) + + return first_observed, last_observed + + def _is_valid_habit(self, habit: BehaviorHabit) -> bool: + """Validate a behavioral habit. 
+ + Args: + habit: Behavioral habit to validate + + Returns: + True if valid, False otherwise + """ + try: + # Check required fields + if not habit.habit_description or not habit.habit_description.strip(): + return False + + # Check time context + if not habit.time_context or not habit.time_context.strip(): + return False + + # Check supporting summaries + if not habit.specific_examples or len(habit.specific_examples) == 0: + return False + + # Check specific examples + if not habit.specific_examples or len(habit.specific_examples) == 0: + return False + + # Check observation dates + if habit.first_observed > habit.last_observed: + return False + + return True + + except Exception as e: + logger.error(f"Error validating habit: {e}") + return False + + def _consolidate_habits( + self, + new_habits: List[BehaviorHabit], + existing_habits: List[BehaviorHabit], + similarity_threshold: float = 0.7 + ) -> List[BehaviorHabit]: + """Consolidate new habits with existing ones. + + Args: + new_habits: Newly extracted habits + existing_habits: Existing habits + similarity_threshold: Threshold for considering habits similar + + Returns: + Consolidated list of habits + """ + consolidated = existing_habits.copy() + current_time = datetime.now() + + for new_habit in new_habits: + # Find similar existing habit + similar_habit = self._find_similar_habit( + new_habit, existing_habits, similarity_threshold + ) + + if similar_habit: + # Update existing habit + updated_habit = self._merge_habits(similar_habit, new_habit, current_time) + # Replace in consolidated list + for i, habit in enumerate(consolidated): + if habit.habit_description == similar_habit.habit_description: + consolidated[i] = updated_habit + break + else: + # Add as new habit + consolidated.append(new_habit) + + return consolidated + + def _find_similar_habit( + self, + target_habit: BehaviorHabit, + existing_habits: List[BehaviorHabit], + threshold: float + ) -> Optional[BehaviorHabit]: + """Find similar habit in 
existing list. + + Args: + target_habit: Habit to find similarity for + existing_habits: List of existing habits + threshold: Similarity threshold + + Returns: + Similar habit if found, None otherwise + """ + target_desc = target_habit.habit_description.lower().strip() + + for existing_habit in existing_habits: + existing_desc = existing_habit.habit_description.lower().strip() + + # Check description similarity + desc_similarity = self._calculate_text_similarity(target_desc, existing_desc) + + # Check frequency pattern match + frequency_match = (target_habit.frequency_pattern == existing_habit.frequency_pattern) + + # Check time context similarity + time_similarity = self._calculate_text_similarity( + target_habit.time_context.lower(), + existing_habit.time_context.lower() + ) + + # Combined similarity score + combined_similarity = (desc_similarity * 0.6 + time_similarity * 0.4) + if frequency_match: + combined_similarity += 0.1 # Bonus for frequency match + + if combined_similarity >= threshold: + return existing_habit + + return None + + def _calculate_text_similarity(self, text1: str, text2: str) -> float: + """Calculate simple text similarity based on common words. + + Args: + text1: First text + text2: Second text + + Returns: + Similarity score between 0.0 and 1.0 + """ + if not text1 or not text2: + return 0.0 + + # Simple word-based similarity + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = words1.intersection(words2) + union = words1.union(words2) + + return len(intersection) / len(union) if union else 0.0 + + def _merge_habits( + self, + existing_habit: BehaviorHabit, + new_habit: BehaviorHabit, + current_time: datetime + ) -> BehaviorHabit: + """Merge two similar habits. 
+ + Args: + existing_habit: Existing habit + new_habit: New habit to merge + current_time: Current timestamp + + Returns: + Merged behavioral habit + """ + # Combine supporting summaries (using specific_examples instead) + combined_examples = list(set( + existing_habit.specific_examples + new_habit.specific_examples + )) + + # Combine specific examples + combined_examples = list(set( + existing_habit.specific_examples + new_habit.specific_examples + )) + + # Update confidence level (take higher confidence) + new_confidence = max(existing_habit.confidence_level, new_habit.confidence_level) + + # Update observation dates + first_observed = min(existing_habit.first_observed, new_habit.first_observed) + last_observed = max(existing_habit.last_observed, new_habit.last_observed) + + # Determine if habit is current (observed within last 30 days) + is_current = (current_time - last_observed).days <= 30 + + # Combine time context + combined_time_context = existing_habit.time_context + if new_habit.time_context and new_habit.time_context not in combined_time_context: + combined_time_context += f"; {new_habit.time_context}" + + return BehaviorHabit( + habit_description=existing_habit.habit_description, # Keep original description + frequency_pattern=existing_habit.frequency_pattern, # Keep original frequency + time_context=combined_time_context, + confidence_level=new_confidence, + specific_examples=combined_examples, + first_observed=first_observed, + last_observed=last_observed, + is_current=is_current + ) + + def _sort_habits_by_priority(self, habits: List[BehaviorHabit]) -> List[BehaviorHabit]: + """Sort habits by confidence level and recency. 
+ + Args: + habits: List of habits to sort + + Returns: + Sorted list of habits + """ + def priority_score(habit: BehaviorHabit) -> tuple: + # Confidence level score (0-100 scale) + confidence_score = habit.confidence_level + + # Recency score (more recent = higher score) + days_since_last = (datetime.now() - habit.last_observed).days + recency_score = max(0, 365 - days_since_last) # Max 365 days + + # Current habit bonus + current_bonus = 100 if habit.is_current else 0 + + return (confidence_score, recency_score + current_bonus, habit.last_observed) + + return sorted(habits, key=priority_score, reverse=True) \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/analyzers/interest_analyzer.py b/api/app/core/memory/analytics/implicit_memory/analyzers/interest_analyzer.py new file mode 100644 index 00000000..dc65d740 --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/analyzers/interest_analyzer.py @@ -0,0 +1,277 @@ +"""Interest Analyzer for Implicit Memory System + +This module implements LLM-based interest area analysis from user memory summaries. +It categorizes user interests into four areas: tech, lifestyle, music, and art, +providing percentage distribution that totals 100%. 
+""" + +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient +from app.core.memory.llm_tools.llm_client import LLMClientException +from app.schemas.implicit_memory_schema import ( + InterestAreaDistribution, + InterestCategory, + UserMemorySummary, +) +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class InterestData(BaseModel): + """Individual interest category analysis data.""" + percentage: float = Field(ge=0.0, le=100.0) + evidence: List[str] = Field(default_factory=list) + trending_direction: Optional[str] = None + + +class InterestAnalysisResponse(BaseModel): + """Response model for interest analysis.""" + interest_distribution: Dict[str, InterestData] = Field(default_factory=dict) + + +class InterestAnalyzer: + """Analyzes user memory summaries to extract interest area distribution.""" + + # Define the four interest categories we analyze + INTEREST_CATEGORIES = ["tech", "lifestyle", "music", "art"] + + def __init__(self, db: Session, llm_model_id: Optional[str] = None): + """Initialize the interest analyzer. + + Args: + db: Database session + llm_model_id: Optional LLM model ID to use for analysis + """ + self.db = db + self.llm_model_id = llm_model_id + self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id) + + async def analyze_interests( + self, + user_id: str, + user_summaries: List[UserMemorySummary], + existing_distribution: Optional[InterestAreaDistribution] = None + ) -> InterestAreaDistribution: + """Analyze user summaries to extract interest area distribution. 
+ + Args: + user_id: Target user ID + user_summaries: List of user-specific memory summaries + existing_distribution: Optional existing distribution for trend tracking + + Returns: + Interest area distribution across four categories + + Raises: + LLMClientException: If LLM analysis fails + """ + if not user_summaries: + logger.warning(f"No summaries provided for user {user_id}") + return self._create_empty_distribution(user_id) + + try: + logger.info(f"Analyzing interests for user {user_id} with {len(user_summaries)} summaries") + + # Use the LLM client wrapper for analysis + response = await self._llm_client.analyze_interests( + user_summaries=user_summaries, + user_id=user_id, + model_id=self.llm_model_id + ) + + # Create interest categories + interest_categories = {} + current_time = datetime.now() + + # Extract interest_distribution from response dict + interest_distribution = response.get("interest_distribution", {}) + + # Extract and validate interest data + raw_interests = {} + for category_name in self.INTEREST_CATEGORIES: + interest_data_dict = interest_distribution.get(category_name) + if interest_data_dict: + raw_interests[category_name] = InterestData( + percentage=interest_data_dict.get("percentage", 0.0), + evidence=interest_data_dict.get("evidence", []), + trending_direction=interest_data_dict.get("trending_direction") + ) + else: + # Create default if missing + logger.warning(f"Missing interest data for {category_name}, using default") + raw_interests[category_name] = InterestData( + percentage=0.0, + evidence=["No specific evidence found"], + trending_direction=None + ) + + # Normalize percentages to ensure they sum to 100% + normalized_interests = self._normalize_percentages(raw_interests) + + # Create interest category objects + for category_name in self.INTEREST_CATEGORIES: + interest_data = normalized_interests[category_name] + + # Calculate trending direction if we have existing data + trending_direction = self._calculate_trending_direction( + 
category_name=category_name, + current_percentage=interest_data.percentage, + existing_distribution=existing_distribution + ) if existing_distribution else interest_data.trending_direction + + interest_categories[category_name] = InterestCategory( + category_name=category_name, + percentage=interest_data.percentage, + evidence=interest_data.evidence if interest_data.evidence else ["No specific evidence found"], + trending_direction=trending_direction + ) + + # Create interest area distribution + distribution = InterestAreaDistribution( + user_id=user_id, + tech=interest_categories["tech"], + lifestyle=interest_categories["lifestyle"], + music=interest_categories["music"], + art=interest_categories["art"], + analysis_timestamp=current_time, + total_summaries_analyzed=len(user_summaries) + ) + + # Validate that percentages sum to 100% + total_percentage = distribution.total_percentage + if not (99.9 <= total_percentage <= 100.1): + logger.warning(f"Interest percentages sum to {total_percentage}, expected ~100%") + + logger.info(f"Created interest distribution for user {user_id}") + return distribution + + except LLMClientException: + raise + except Exception as e: + logger.error(f"Interest analysis failed for user {user_id}: {e}") + raise LLMClientException(f"Interest analysis failed: {e}") from e + + def _normalize_percentages(self, raw_interests: Dict[str, InterestData]) -> Dict[str, InterestData]: + """Normalize percentages to ensure they sum to 100%. 
+ + Args: + raw_interests: Raw interest data with potentially unnormalized percentages + + Returns: + Normalized interest data + """ + # Calculate current total + total = sum(interest.percentage for interest in raw_interests.values()) + + if total == 0: + # If all percentages are 0, distribute equally + equal_percentage = 100.0 / len(self.INTEREST_CATEGORIES) + normalized = {} + for category_name, interest_data in raw_interests.items(): + normalized[category_name] = InterestData( + percentage=equal_percentage, + evidence=interest_data.evidence, + trending_direction=interest_data.trending_direction + ) + return normalized + + # Normalize to sum to 100% + normalization_factor = 100.0 / total + normalized = {} + + for category_name, interest_data in raw_interests.items(): + normalized_percentage = interest_data.percentage * normalization_factor + + normalized[category_name] = InterestData( + percentage=round(normalized_percentage, 1), + evidence=interest_data.evidence, + trending_direction=interest_data.trending_direction + ) + + # Handle rounding errors by adjusting the largest category + current_total = sum(interest.percentage for interest in normalized.values()) + if abs(current_total - 100.0) > 0.1: + # Find category with largest percentage and adjust + largest_category = max(normalized.keys(), key=lambda k: normalized[k].percentage) + adjustment = 100.0 - current_total + + adjusted_percentage = normalized[largest_category].percentage + adjustment + normalized[largest_category] = InterestData( + percentage=round(max(0.0, adjusted_percentage), 1), + evidence=normalized[largest_category].evidence, + trending_direction=normalized[largest_category].trending_direction + ) + + return normalized + + def _calculate_trending_direction( + self, + category_name: str, + current_percentage: float, + existing_distribution: InterestAreaDistribution, + threshold: float = 5.0 + ) -> Optional[str]: + """Calculate trending direction for an interest category. 
+ + Args: + category_name: Name of the interest category + current_percentage: Current percentage for the category + existing_distribution: Previous distribution for comparison + threshold: Minimum percentage change to consider a trend + + Returns: + Trending direction: "increasing", "decreasing", "stable", or None + """ + try: + # Get previous percentage + previous_category = getattr(existing_distribution, category_name, None) + if not previous_category: + return None + + previous_percentage = previous_category.percentage + change = current_percentage - previous_percentage + + if abs(change) < threshold: + return "stable" + elif change > 0: + return "increasing" + else: + return "decreasing" + + except Exception as e: + logger.error(f"Error calculating trending direction for {category_name}: {e}") + return None + + def _create_empty_distribution(self, user_id: str) -> InterestAreaDistribution: + """Create an empty interest distribution when no data is available. + + Args: + user_id: Target user ID + + Returns: + Empty InterestAreaDistribution with equal percentages + """ + current_time = datetime.now() + equal_percentage = 25.0 # 100% / 4 categories + + default_category = lambda name: InterestCategory( + category_name=name, + percentage=equal_percentage, + evidence=["Insufficient data for analysis"], + trending_direction=None + ) + + return InterestAreaDistribution( + user_id=user_id, + tech=default_category("tech"), + lifestyle=default_category("lifestyle"), + music=default_category("music"), + art=default_category("art"), + analysis_timestamp=current_time, + total_summaries_analyzed=0 + ) \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/analyzers/preference_analyzer.py b/api/app/core/memory/analytics/implicit_memory/analyzers/preference_analyzer.py new file mode 100644 index 00000000..418a3c37 --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/analyzers/preference_analyzer.py @@ -0,0 +1,302 @@ +"""Preference 
Analyzer for Implicit Memory System + +This module implements LLM-based preference extraction from user memory summaries. +It identifies implicit preferences, consolidates similar preferences, and calculates +confidence scores based on evidence strength. +""" + +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient +from app.core.memory.llm_tools.llm_client import LLMClientException +from app.schemas.implicit_memory_schema import ( + PreferenceTag, + UserMemorySummary, +) +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class PreferenceAnalysisResponse(BaseModel): + """Response model for preference analysis.""" + preferences: List[Dict[str, Any]] = Field(default_factory=list) + + +class PreferenceAnalyzer: + """Analyzes user memory summaries to extract implicit preferences.""" + + def __init__(self, db: Session, llm_model_id: Optional[str] = None): + """Initialize the preference analyzer. + + Args: + db: Database session + llm_model_id: Optional LLM model ID to use for analysis + """ + self.db = db + self.llm_model_id = llm_model_id + self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id) + + async def analyze_preferences( + self, + user_id: str, + user_summaries: List[UserMemorySummary], + existing_preferences: Optional[List[PreferenceTag]] = None + ) -> List[PreferenceTag]: + """Analyze user summaries to extract preferences. 
+ + Args: + user_id: Target user ID + user_summaries: List of user-specific memory summaries + existing_preferences: Optional existing preferences for consolidation + + Returns: + List of extracted preference tags + + Raises: + LLMClientException: If LLM analysis fails + """ + if not user_summaries: + logger.warning(f"No summaries provided for user {user_id}") + return [] + + try: + logger.info(f"Analyzing preferences for user {user_id} with {len(user_summaries)} summaries") + + # Use the LLM client wrapper for analysis + response = await self._llm_client.analyze_preferences( + user_summaries=user_summaries, + user_id=user_id, + model_id=self.llm_model_id + ) + + # Convert to PreferenceTag objects + preference_tags = [] + current_time = datetime.now() + + for pref_data in response.get("preferences", []): + try: + # Extract conversation references from summaries + conversation_refs = [s.summary_id for s in user_summaries] + + preference_tag = PreferenceTag( + tag_name=pref_data.get("tag_name", ""), + confidence_score=float(pref_data.get("confidence_score", 0.0)), + supporting_evidence=pref_data.get("supporting_evidence", []), + context_details=pref_data.get("context_details", ""), + category=pref_data.get("category"), + conversation_references=conversation_refs, + created_at=current_time, + updated_at=current_time + ) + + # Validate preference tag + if self._is_valid_preference(preference_tag): + preference_tags.append(preference_tag) + else: + logger.warning(f"Invalid preference tag skipped: {preference_tag.tag_name}") + + except Exception as e: + logger.error(f"Error creating preference tag: {e}") + continue + + # Consolidate with existing preferences if provided + if existing_preferences: + preference_tags = self._consolidate_preferences( + new_preferences=preference_tags, + existing_preferences=existing_preferences + ) + + logger.info(f"Extracted {len(preference_tags)} preferences for user {user_id}") + return preference_tags + + except LLMClientException: + 
raise + except Exception as e: + logger.error(f"Preference analysis failed for user {user_id}: {e}") + raise LLMClientException(f"Preference analysis failed: {e}") from e + + def _is_valid_preference(self, preference: PreferenceTag) -> bool: + """Validate a preference tag. + + Args: + preference: Preference tag to validate + + Returns: + True if valid, False otherwise + """ + try: + # Check required fields + if not preference.tag_name or not preference.tag_name.strip(): + return False + + # Check confidence score range + if not (0.0 <= preference.confidence_score <= 1.0): + return False + + # Check supporting evidence + if not preference.supporting_evidence or len(preference.supporting_evidence) == 0: + return False + + # Check context details + if not preference.context_details or not preference.context_details.strip(): + return False + + return True + + except Exception as e: + logger.error(f"Error validating preference: {e}") + return False + + def _consolidate_preferences( + self, + new_preferences: List[PreferenceTag], + existing_preferences: List[PreferenceTag], + similarity_threshold: float = 0.8 + ) -> List[PreferenceTag]: + """Consolidate new preferences with existing ones. 
+ + Args: + new_preferences: Newly extracted preferences + existing_preferences: Existing preferences + similarity_threshold: Threshold for considering preferences similar + + Returns: + Consolidated list of preferences + """ + consolidated = existing_preferences.copy() + current_time = datetime.now() + + for new_pref in new_preferences: + # Find similar existing preference + similar_pref = self._find_similar_preference( + new_pref, existing_preferences, similarity_threshold + ) + + if similar_pref: + # Update existing preference + updated_pref = self._merge_preferences(similar_pref, new_pref, current_time) + # Replace in consolidated list + for i, pref in enumerate(consolidated): + if pref.tag_name == similar_pref.tag_name: + consolidated[i] = updated_pref + break + else: + # Add as new preference + consolidated.append(new_pref) + + return consolidated + + def _find_similar_preference( + self, + target_preference: PreferenceTag, + existing_preferences: List[PreferenceTag], + threshold: float + ) -> Optional[PreferenceTag]: + """Find similar preference in existing list. + + Args: + target_preference: Preference to find similarity for + existing_preferences: List of existing preferences + threshold: Similarity threshold + + Returns: + Similar preference if found, None otherwise + """ + target_name = target_preference.tag_name.lower().strip() + + for existing_pref in existing_preferences: + existing_name = existing_pref.tag_name.lower().strip() + + # Simple similarity check based on common words + similarity = self._calculate_text_similarity(target_name, existing_name) + + if similarity >= threshold: + return existing_pref + + return None + + def _calculate_text_similarity(self, text1: str, text2: str) -> float: + """Calculate simple text similarity based on common words. 
+ + Args: + text1: First text + text2: Second text + + Returns: + Similarity score between 0.0 and 1.0 + """ + if not text1 or not text2: + return 0.0 + + # Simple word-based similarity + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = words1.intersection(words2) + union = words1.union(words2) + + return len(intersection) / len(union) if union else 0.0 + + def _merge_preferences( + self, + existing_pref: PreferenceTag, + new_pref: PreferenceTag, + current_time: datetime + ) -> PreferenceTag: + """Merge two similar preferences. + + Args: + existing_pref: Existing preference + new_pref: New preference to merge + current_time: Current timestamp + + Returns: + Merged preference tag + """ + # Combine supporting evidence + combined_evidence = list(set( + existing_pref.supporting_evidence + new_pref.supporting_evidence + )) + + # Combine conversation references + combined_refs = list(set( + existing_pref.conversation_references + new_pref.conversation_references + )) + + # Calculate new confidence score (weighted average) + evidence_weight = len(new_pref.supporting_evidence) + total_weight = len(existing_pref.supporting_evidence) + evidence_weight + + if total_weight > 0: + new_confidence = ( + (existing_pref.confidence_score * len(existing_pref.supporting_evidence) + + new_pref.confidence_score * evidence_weight) / total_weight + ) + else: + new_confidence = max(existing_pref.confidence_score, new_pref.confidence_score) + + # Ensure confidence doesn't exceed 1.0 + new_confidence = min(new_confidence, 1.0) + + # Combine context details + combined_context = existing_pref.context_details + if new_pref.context_details and new_pref.context_details not in combined_context: + combined_context += f"; {new_pref.context_details}" + + return PreferenceTag( + tag_name=existing_pref.tag_name, # Keep original name + confidence_score=new_confidence, + supporting_evidence=combined_evidence, + 
context_details=combined_context, + category=existing_pref.category or new_pref.category, + conversation_references=combined_refs, + created_at=existing_pref.created_at, + updated_at=current_time + ) \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/data_source.py b/api/app/core/memory/analytics/implicit_memory/data_source.py new file mode 100644 index 00000000..d277a05e --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/data_source.py @@ -0,0 +1,97 @@ +""" +Memory Data Source + +Handles retrieval and processing of memory data from Neo4j using direct Cypher queries. +""" + +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from app.repositories.neo4j.memory_summary_repository import MemorySummaryRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas.implicit_memory_schema import TimeRange, UserMemorySummary +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class MemoryDataSource: + """Retrieves processed memory data from Neo4j using direct Cypher queries.""" + + def __init__( + self, + db: Session, + neo4j_connector: Optional[Neo4jConnector] = None + ): + self.db = db + self.neo4j_connector = neo4j_connector or Neo4jConnector() + self.memory_summary_repo = MemorySummaryRepository(self.neo4j_connector) + + def _parse_timestamp(self, timestamp: Any) -> datetime: + """Parse timestamp from various formats.""" + if isinstance(timestamp, str): + return datetime.fromisoformat(timestamp.replace('Z', '+00:00')) + elif timestamp is None: + return datetime.now() + return timestamp + + def _dict_to_user_summary(self, summary_dict: Dict, user_id: str) -> Optional[UserMemorySummary]: + """Convert a Neo4j dict directly to UserMemorySummary.""" + try: + content = summary_dict.get("content", summary_dict.get("summary", "")) + if not content or not content.strip(): + return None + + return UserMemorySummary( + 
summary_id=summary_dict.get("id", summary_dict.get("uuid", "")), + user_id=user_id, + user_content=content, + timestamp=self._parse_timestamp(summary_dict.get("created_at")), + confidence_score=1.0, + summary_type="memory_summary" + ) + except Exception as e: + logger.warning(f"Failed to parse summary {summary_dict.get('id', 'unknown')}: {e}") + return None + + async def get_user_summaries( + self, + user_id: str, + time_range: Optional[TimeRange] = None, + limit: int = 1000 + ) -> List[UserMemorySummary]: + """Retrieve user memory summaries from Neo4j. + + Args: + user_id: Target user ID + time_range: Optional time range filter + limit: Maximum number of summaries + + Returns: + List of user memory summaries + """ + try: + start_date = time_range.start_date if time_range else None + end_date = time_range.end_date if time_range else None + + summary_dicts = await self.memory_summary_repo.find_by_group_id( + group_id=user_id, + limit=limit, + start_date=start_date, + end_date=end_date + ) + + summaries = [] + for summary_dict in summary_dicts: + summary = self._dict_to_user_summary(summary_dict, user_id) + if summary: + summaries.append(summary) + + logger.info(f"Retrieved {len(summaries)} summaries for user {user_id}") + return summaries + + except Exception as e: + logger.error(f"Failed to retrieve summaries for user {user_id}: {e}") + raise + \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/habit_detector.py b/api/app/core/memory/analytics/implicit_memory/habit_detector.py new file mode 100644 index 00000000..4f0bcc3e --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/habit_detector.py @@ -0,0 +1,226 @@ +"""Habit Detector for Implicit Memory System + +This module implements the HabitDetector class that specializes in identifying +and ranking behavioral habits from user memory summaries. It provides advanced +habit analysis with confidence scoring, recency weighting, and current vs past +habit distinction. 
+""" + +import logging +from datetime import datetime, timedelta +from typing import List, Optional + +from app.core.memory.analytics.implicit_memory.analyzers.habit_analyzer import ( + HabitAnalyzer, +) +from app.core.memory.llm_tools.llm_client import LLMClientException +from app.schemas.implicit_memory_schema import ( + BehaviorHabit, + FrequencyPattern, + UserMemorySummary, +) +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +class HabitDetector: + """Detects and ranks behavioral habits from user memory summaries.""" + + def __init__( + self, + db: Session, + llm_model_id: Optional[str] = None + ): + """Initialize the habit detector. + + Args: + db: Database session + llm_model_id: Optional LLM model ID to use for analysis + """ + self.db = db + self.llm_model_id = llm_model_id + self.habit_analyzer = HabitAnalyzer(db, llm_model_id) + + async def detect_habits( + self, + user_id: str, + user_summaries: List[UserMemorySummary], + existing_habits: Optional[List[BehaviorHabit]] = None + ) -> List[BehaviorHabit]: + """Detect behavioral habits from user summaries. 
+ + Args: + user_id: Target user ID + user_summaries: List of user-specific memory summaries + existing_habits: Optional existing habits for consolidation + + Returns: + List of detected and ranked behavioral habits + + Raises: + LLMClientException: If habit analysis fails + """ + if not user_summaries: + logger.warning(f"No summaries provided for user {user_id}") + return existing_habits or [] + + logger.info(f"Detecting habits for user {user_id} with {len(user_summaries)} summaries") + + try: + # Use the habit analyzer to extract habits + detected_habits = await self.habit_analyzer.analyze_habits( + user_id=user_id, + user_summaries=user_summaries, + existing_habits=existing_habits + ) + + # Apply advanced ranking and filtering + ranked_habits = self.rank_habits_by_confidence_and_recency(detected_habits) + + # Distinguish current vs past habits + categorized_habits = self.distinguish_current_vs_past_habits(ranked_habits) + + logger.info(f"Detected {len(categorized_habits)} habits for user {user_id}") + return categorized_habits + + except LLMClientException: + logger.error(f"Habit detection failed for user {user_id}") + raise + except Exception as e: + logger.error(f"Habit detection failed for user {user_id}: {e}") + raise LLMClientException(f"Habit detection failed: {e}") from e + + def rank_habits_by_confidence_and_recency( + self, + habits: List[BehaviorHabit], + confidence_weight: float = 0.6, + recency_weight: float = 0.4 + ) -> List[BehaviorHabit]: + """Rank habits by confidence level and recency. 
+ + Args: + habits: List of habits to rank + confidence_weight: Weight for confidence score (0.0-1.0) + recency_weight: Weight for recency score (0.0-1.0) + + Returns: + List of habits ranked by combined score + """ + if not habits: + return [] + + logger.info(f"Ranking {len(habits)} habits by confidence and recency") + + def calculate_ranking_score(habit: BehaviorHabit) -> float: + """Calculate combined ranking score for a habit.""" + + # Confidence score (0.0-1.0) - convert from 0-100 scale + confidence_score = habit.confidence_level / 100.0 + + # Recency score (0.0-1.0) + current_time = datetime.now() + days_since_last = (current_time - habit.last_observed).days + + # Exponential decay for recency (habits lose relevance over time) + if days_since_last <= 7: + recency_score = 1.0 # Very recent + elif days_since_last <= 30: + recency_score = 0.8 # Recent + elif days_since_last <= 90: + recency_score = 0.5 # Somewhat recent + elif days_since_last <= 180: + recency_score = 0.3 # Old + else: + recency_score = 0.1 # Very old + + # Frequency pattern bonus + frequency_bonuses = { + FrequencyPattern.DAILY: 0.2, + FrequencyPattern.WEEKLY: 0.15, + FrequencyPattern.MONTHLY: 0.1, + FrequencyPattern.SEASONAL: 0.05, + FrequencyPattern.OCCASIONAL: 0.0, + FrequencyPattern.EVENT_TRIGGERED: 0.05 + } + frequency_bonus = frequency_bonuses.get(habit.frequency_pattern, 0.0) + + # Evidence quality bonus + evidence_bonus = min(len(habit.specific_examples) / 10.0, 0.1) # Max 0.1 bonus + + # Current habit bonus + current_bonus = 0.1 if habit.is_current else 0.0 + + # Calculate final score + base_score = (confidence_score * confidence_weight + + recency_score * recency_weight) + + final_score = base_score + frequency_bonus + evidence_bonus + current_bonus + + return min(final_score, 1.0) # Cap at 1.0 + + # Sort habits by ranking score (descending) + ranked_habits = sorted(habits, key=calculate_ranking_score, reverse=True) + + logger.info(f"Ranked habits with scores: 
{[calculate_ranking_score(h) for h in ranked_habits[:5]]}") + + return ranked_habits + + def distinguish_current_vs_past_habits( + self, + habits: List[BehaviorHabit], + current_threshold_days: int = 30 + ) -> List[BehaviorHabit]: + """Distinguish between current and past habits based on recency. + + Args: + habits: List of habits to categorize + current_threshold_days: Days threshold for considering a habit current + + Returns: + List of habits with updated is_current status + """ + if not habits: + return [] + + current_time = datetime.now() + cutoff_date = current_time - timedelta(days=current_threshold_days) + + current_habits = [] + past_habits = [] + + for habit in habits: + # Update is_current status based on last observation + if habit.last_observed >= cutoff_date: + # Create updated habit with is_current = True + updated_habit = BehaviorHabit( + habit_description=habit.habit_description, + frequency_pattern=habit.frequency_pattern, + time_context=habit.time_context, + confidence_level=habit.confidence_level, + specific_examples=habit.specific_examples, + first_observed=habit.first_observed, + last_observed=habit.last_observed, + is_current=True + ) + current_habits.append(updated_habit) + else: + # Create updated habit with is_current = False + updated_habit = BehaviorHabit( + habit_description=habit.habit_description, + frequency_pattern=habit.frequency_pattern, + time_context=habit.time_context, + confidence_level=habit.confidence_level, + specific_examples=habit.specific_examples, + first_observed=habit.first_observed, + last_observed=habit.last_observed, + is_current=False + ) + past_habits.append(updated_habit) + + # Return current habits first, then past habits + categorized_habits = current_habits + past_habits + + logger.info(f"Categorized habits: {len(current_habits)} current, {len(past_habits)} past") + + return categorized_habits \ No newline at end of file diff --git a/api/app/core/memory/analytics/implicit_memory/llm_client.py 
"""LLM Client Wrapper for Implicit Memory Analysis

This module provides a specialized LLM client wrapper that integrates with the
MemoryClientFactory to perform implicit memory analysis tasks including preference
extraction, personality dimension analysis, interest categorization, and habit detection.
"""

import logging
from typing import Any, Callable, Dict, List, Optional, Type

from app.core.memory.analytics.implicit_memory.prompts import (
    get_dimension_analysis_prompt,
    get_habit_analysis_prompt,
    get_interest_analysis_prompt,
    get_preference_analysis_prompt,
)
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.schemas.implicit_memory_schema import UserMemorySummary
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


# Response Models for LLM Analysis

class PreferenceAnalysisResponse(BaseModel):
    """Response model for preference analysis."""
    preferences: List[Dict[str, Any]] = Field(default_factory=list)


class DimensionAnalysisResponse(BaseModel):
    """Response model for dimension analysis."""
    dimensions: Dict[str, Dict[str, Any]] = Field(default_factory=dict)


class InterestAnalysisResponse(BaseModel):
    """Response model for interest analysis."""
    interest_distribution: Dict[str, Dict[str, Any]] = Field(default_factory=dict)


class HabitAnalysisResponse(BaseModel):
    """Response model for habit analysis."""
    habits: List[Dict[str, Any]] = Field(default_factory=list)


class ImplicitMemoryLLMClient:
    """LLM client wrapper for implicit memory analysis.

    This class provides a high-level interface for performing LLM-based analysis
    of user memory summaries to extract preferences, personality dimensions,
    interests, and behavioral habits.

    The four public ``analyze_*`` coroutines share one pipeline, implemented in
    ``_run_analysis``; they differ only in prompt, response model and log text.
    """

    def __init__(self, db: Session, default_model_id: Optional[str] = None):
        """Initialize the LLM client wrapper.

        Args:
            db: Database session for accessing model configurations
            default_model_id: Default LLM model ID to use if none specified
        """
        self.db = db
        self.default_model_id = default_model_id
        self._client_factory = MemoryClientFactory(db)

        logger.info("ImplicitMemoryLLMClient initialized")

    def _get_llm_client(self, model_id: Optional[str] = None):
        """Get LLM client instance.

        Args:
            model_id: LLM model ID to use, defaults to default_model_id

        Returns:
            LLM client instance

        Raises:
            ValueError: If no model ID is provided and no default is set
            LLMClientException: If client creation fails
        """
        effective_model_id = model_id or self.default_model_id
        if not effective_model_id:
            raise ValueError("No LLM model ID provided and no default model ID set")

        try:
            client = self._client_factory.get_llm_client(effective_model_id)
            logger.debug(f"Created LLM client for model: {effective_model_id}")
            return client
        except Exception as e:
            logger.error(f"Failed to create LLM client for model {effective_model_id}: {e}")
            raise LLMClientException(f"Failed to create LLM client: {e}") from e

    def _prepare_summaries_for_analysis(self, user_summaries: List[UserMemorySummary]) -> List[Dict[str, Any]]:
        """Prepare user memory summaries for LLM analysis.

        Args:
            user_summaries: List of user memory summaries

        Returns:
            List of formatted summary dictionaries, one per input summary, in
            input order. The timestamp is rendered with ``isoformat()``.
        """
        formatted_summaries = [
            {
                'summary_id': summary.summary_id,
                'user_content': summary.user_content,
                'timestamp': summary.timestamp.isoformat(),
                'summary_type': summary.summary_type,
                'confidence_score': summary.confidence_score
            }
            for summary in user_summaries
        ]

        logger.debug(f"Prepared {len(formatted_summaries)} summaries for analysis")
        return formatted_summaries

    async def _run_analysis(
        self,
        *,
        subject: str,
        prompt_builder: Callable[[List[Dict[str, Any]], str], str],
        response_model: Type[BaseModel],
        describe_result: Callable[[Dict[str, Any]], str],
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run the analysis pipeline shared by every public ``analyze_*`` method.

        Steps: format the summaries, render the prompt, obtain the LLM client,
        request a structured response and dump it to a plain dictionary.

        Args:
            subject: Lower-case singular name of the analysis ("preference",
                "dimension", "interest", "habit"); used in log and error text
            prompt_builder: Function turning (formatted summaries, user_id)
                into the prompt string
            response_model: Pydantic model the LLM response is parsed into
            describe_result: Builds the tail of the success log line from the
                dumped result
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            ``response_model`` dumped to a dictionary

        Raises:
            ValueError: If summaries are given but ``user_id`` is empty
            LLMClientException: If any step of the pipeline fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for {subject} analysis of user {user_id}")
            # Every response model has a single defaulted field, so an empty
            # instance dumps to exactly the "nothing found" payload
            # (e.g. {"preferences": []}).
            return response_model().model_dump()

        if not user_id:
            raise ValueError(f"User ID is required for {subject} analysis")

        try:
            # Prepare summaries and get prompt
            formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
            prompt = prompt_builder(formatted_summaries, user_id)

            # Get LLM client and perform analysis
            llm_client = self._get_llm_client(model_id)

            messages = [{"role": "user", "content": prompt}]

            # Use structured output for reliable parsing
            response = await llm_client.response_structured(
                messages=messages,
                response_model=response_model
            )

            result = response.model_dump()
            logger.info(f"Analyzed {subject}s for user {user_id}: {describe_result(result)}")
            return result

        except Exception as e:
            # Everything raised inside the pipeline (including the ValueError
            # for a missing model ID) is reported as LLMClientException.
            label = subject.capitalize()
            logger.error(f"{label} analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"{label} analysis failed: {e}") from e

    async def analyze_preferences(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user preferences from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing extracted preferences; ``{"preferences": []}``
            when no summaries are given

        Raises:
            LLMClientException: If LLM analysis fails (including a missing model ID)
            ValueError: If summaries are given but ``user_id`` is empty
        """
        return await self._run_analysis(
            subject="preference",
            prompt_builder=get_preference_analysis_prompt,
            response_model=PreferenceAnalysisResponse,
            describe_result=lambda result: f"found {len(result.get('preferences', []))} preferences",
            user_summaries=user_summaries,
            user_id=user_id,
            model_id=model_id,
        )

    async def analyze_dimensions(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user personality dimensions from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing dimension scores and analysis;
            ``{"dimensions": {}}`` when no summaries are given

        Raises:
            LLMClientException: If LLM analysis fails (including a missing model ID)
            ValueError: If summaries are given but ``user_id`` is empty
        """
        return await self._run_analysis(
            subject="dimension",
            prompt_builder=get_dimension_analysis_prompt,
            response_model=DimensionAnalysisResponse,
            describe_result=lambda result: f"{list(result.get('dimensions', {}).keys())}",
            user_summaries=user_summaries,
            user_id=user_id,
            model_id=model_id,
        )

    async def analyze_interests(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user interest distribution from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing interest area distribution;
            ``{"interest_distribution": {}}`` when no summaries are given

        Raises:
            LLMClientException: If LLM analysis fails (including a missing model ID)
            ValueError: If summaries are given but ``user_id`` is empty
        """
        return await self._run_analysis(
            subject="interest",
            prompt_builder=get_interest_analysis_prompt,
            response_model=InterestAnalysisResponse,
            describe_result=lambda result: f"{list(result.get('interest_distribution', {}).keys())}",
            user_summaries=user_summaries,
            user_id=user_id,
            model_id=model_id,
        )

    async def analyze_habits(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user behavioral habits from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing identified behavioral habits;
            ``{"habits": []}`` when no summaries are given

        Raises:
            LLMClientException: If LLM analysis fails (including a missing model ID)
            ValueError: If summaries are given but ``user_id`` is empty
        """
        return await self._run_analysis(
            subject="habit",
            prompt_builder=get_habit_analysis_prompt,
            response_model=HabitAnalysisResponse,
            describe_result=lambda result: f"found {len(result.get('habits', []))} habits",
            user_summaries=user_summaries,
            user_id=user_id,
            model_id=model_id,
        )
"""LLM Prompt Templates for Implicit Memory Analysis

This module contains prompt rendering functions for analyzing user memory summaries
to extract preferences, personality dimensions, interests, and behavioral habits.
"""

import os
from typing import Any, Dict, List

from jinja2 import Environment, FileSystemLoader

# Setup Jinja2 environment: templates are loaded from the "prompts" directory
# that sits next to this module.
current_dir = os.path.dirname(os.path.abspath(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))


def _render_template(template_name: str, **context) -> str:
    """Render the named template with the given keyword arguments as context."""
    return prompt_env.get_template(template_name).render(**context)


def get_preference_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the preference analysis prompt for the given summaries and user."""
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("preference_analysis.jinja2", **context)


def get_dimension_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the dimension analysis prompt for the given summaries and user."""
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("dimension_analysis.jinja2", **context)


def get_interest_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the interest analysis prompt for the given summaries and user."""
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("interest_analysis.jinja2", **context)


def get_habit_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the habit analysis prompt for the given summaries and user."""
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("habit_analysis.jinja2", **context)
b/api/app/core/memory/analytics/implicit_memory/prompts/dimension_analysis.jinja2 @@ -0,0 +1,41 @@ +You are an expert personality analyst. Analyze memory summaries to assess the user's personality across four dimensions. + +## Memory Summaries +{% for summary in memory_summaries %} +Summary {{ loop.index }}: +{{ summary.content or summary.user_content or '' }} +--- +{% endfor %} + +## Target User ID +{{ user_id }} + +## Dimensions to Analyze +1. **Creativity** (0-100%): Creative thinking, artistic interests, innovative ideas +2. **Aesthetic** (0-100%): Design preferences, visual interests, artistic appreciation +3. **Technology** (0-100%): Technical discussions, tool usage, programming interests +4. **Literature** (0-100%): Reading habits, writing style, literary references + +## Instructions +1. Analyze the user's content for each dimension +2. Calculate percentage scores (0-100%) + +## Output Format +{ + "dimensions": { + "creativity": {"percentage": 0-100}, + "aesthetic": {"percentage": 0-100}, + "technology": {"percentage": 0-100}, + "literature": {"percentage": 0-100} + } +} + +## Example +{ + "dimensions": { + "creativity": {"percentage": 75}, + "aesthetic": {"percentage": 45}, + "technology": {"percentage": 60}, + "literature": {"percentage": 30} + } +} diff --git a/api/app/core/memory/analytics/implicit_memory/prompts/habit_analysis.jinja2 b/api/app/core/memory/analytics/implicit_memory/prompts/habit_analysis.jinja2 new file mode 100644 index 00000000..7e78ee36 --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/prompts/habit_analysis.jinja2 @@ -0,0 +1,70 @@ +You are an expert at identifying behavioral patterns and habits from memory summaries. + +## Memory Summaries +{% for summary in memory_summaries %} +Summary {{ loop.index }}: +{{ summary.content or summary.user_content or '' }} +--- +{% endfor %} + +## Target User ID +{{ user_id }} + +## Instructions +1. Identify recurring behavioral patterns mentioned by the SPECIFIED USER +2. 
Focus on specific, concrete habits with temporal patterns +3. For each habit, provide: + - habit_description: Clear, specific description + - frequency_pattern: "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered" + - time_context: When it typically happens + - confidence_level: "high", "medium", "low" + - supporting_summaries: References to evidence + - specific_examples: Concrete examples from summaries + - is_current: true if current habit, false if past habit +4. Only include habits with medium or high confidence +5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.** + +## Output Format +{ + "habits": [ + { + "habit_description": "string", + "frequency_pattern": "daily|weekly|monthly|seasonal|occasional|event_triggered", + "time_context": "string", + "confidence_level": "high|medium|low", + "supporting_summaries": ["id1", "id2"], + "specific_examples": ["example1", "example2"], + "is_current": true|false + } + ] +} + +## Example (English input → English output) +{ + "habits": [ + { + "habit_description": "drinks coffee every morning", + "frequency_pattern": "daily", + "time_context": "morning routine", + "confidence_level": "high", + "supporting_summaries": ["s1", "s2"], + "specific_examples": ["needs coffee to start the day"], + "is_current": true + } + ] +} + +## Example (Chinese input → Chinese output) +{ + "habits": [ + { + "habit_description": "每天早上喝咖啡", + "frequency_pattern": "daily", + "time_context": "早晨日常", + "confidence_level": "high", + "supporting_summaries": ["s1", "s2"], + "specific_examples": ["需要咖啡来开始一天"], + "is_current": true + } + ] +} diff --git a/api/app/core/memory/analytics/implicit_memory/prompts/interest_analysis.jinja2 b/api/app/core/memory/analytics/implicit_memory/prompts/interest_analysis.jinja2 new file mode 100644 index 00000000..4b54190f --- /dev/null +++ 
b/api/app/core/memory/analytics/implicit_memory/prompts/interest_analysis.jinja2 @@ -0,0 +1,54 @@ +You are an expert at analyzing user interests from memory summaries. + +## Memory Summaries +{% for summary in memory_summaries %} +Summary {{ loop.index }}: +{{ summary.content or summary.user_content or '' }} +--- +{% endfor %} + +## Target User ID +{{ user_id }} + +## Interest Categories +1. **Tech**: Programming, technology, software tools, hardware +2. **Lifestyle**: Daily routines, health, hobbies, social activities +3. **Music**: Music preferences, instruments, concerts +4. **Art**: Visual arts, creative projects, design, aesthetics + +## Instructions +1. Categorize the user's interests into the four areas +2. Calculate percentage distribution (must total 100%) +3. Provide specific evidence for each interest area +4. Use "increasing", "decreasing", or "stable" for trending direction +5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. 
If in English, output in English.** + +## Output Format +{ + "interest_distribution": { + "tech": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}, + "lifestyle": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}, + "music": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}, + "art": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"} + } +} + +## Example (English input → English output) +{ + "interest_distribution": { + "tech": {"percentage": 40, "evidence": ["discusses programming frequently"], "trending_direction": "increasing"}, + "lifestyle": {"percentage": 35, "evidence": ["talks about fitness routine"], "trending_direction": "stable"}, + "music": {"percentage": 15, "evidence": ["mentioned favorite bands"], "trending_direction": "stable"}, + "art": {"percentage": 10, "evidence": ["visited art museum"], "trending_direction": "stable"} + } +} + +## Example (Chinese input → Chinese output) +{ + "interest_distribution": { + "tech": {"percentage": 40, "evidence": ["经常讨论编程"], "trending_direction": "increasing"}, + "lifestyle": {"percentage": 35, "evidence": ["谈论健身日常"], "trending_direction": "stable"}, + "music": {"percentage": 15, "evidence": ["提到喜欢的乐队"], "trending_direction": "stable"}, + "art": {"percentage": 10, "evidence": ["参观了艺术博物馆"], "trending_direction": "stable"} + } +} diff --git a/api/app/core/memory/analytics/implicit_memory/prompts/preference_analysis.jinja2 b/api/app/core/memory/analytics/implicit_memory/prompts/preference_analysis.jinja2 new file mode 100644 index 00000000..fd7c8436 --- /dev/null +++ b/api/app/core/memory/analytics/implicit_memory/prompts/preference_analysis.jinja2 @@ -0,0 +1,47 @@ +You are an expert at analyzing user memory summaries to identify implicit preferences. 
+ +## Memory Summaries +{% for summary in memory_summaries %} +Summary {{ loop.index }}: +{{ summary.content or summary.user_content or '' }} +--- +{% endfor %} + +## Target User ID +{{ user_id }} + +## Instructions +1. Focus ONLY on the specified user's preferences +2. Extract SHORT preference tags (1-3 words max), like: "音乐", "咖啡", "科幻", "设计", "古典", "吉他" +3. DO NOT use long phrases - use short nouns or noun phrases +4. Only include preferences with confidence_score >= 0.3 +5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.** + +## Output Format +{ + "preferences": [ + { + "tag_name": "short tag", + "confidence_score": 0.0-1.0, + "supporting_evidence": ["evidence1", "evidence2"], + "context_details": "brief context", + "category": "category or null" + } + ] +} + +## Example (Chinese input → Chinese output) +{ + "preferences": [ + {"tag_name": "咖啡", "confidence_score": 0.8, "supporting_evidence": ["每天早上喝咖啡"], "context_details": "日常习惯", "category": "lifestyle"}, + {"tag_name": "古典音乐", "confidence_score": 0.7, "supporting_evidence": ["喜欢听古典"], "context_details": "音乐偏好", "category": "music"} + ] +} + +## Example (English input → English output) +{ + "preferences": [ + {"tag_name": "coffee", "confidence_score": 0.8, "supporting_evidence": ["drinks coffee every morning"], "context_details": "daily routine", "category": "lifestyle"}, + {"tag_name": "classical music", "confidence_score": 0.7, "supporting_evidence": ["enjoys classical"], "context_details": "music preference", "category": "music"} + ] +} diff --git a/api/app/core/memory/analytics/memory_insight.py b/api/app/core/memory/analytics/memory_insight.py deleted file mode 100644 index 39746e58..00000000 --- a/api/app/core/memory/analytics/memory_insight.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -This module provides the MemoryInsight class for analyzing user memory data. 
- -MemoryInsight 是一个工具类,提供基础的数据获取和分析功能: -- get_domain_distribution(): 获取记忆领域分布 -- get_active_periods(): 获取活跃时段 -- get_social_connections(): 获取社交关联 - -业务逻辑(如生成洞察报告)应该在服务层(user_memory_service.py)中实现。 - -This script can be executed directly to test the memory insight generation for a test user. -""" - -import asyncio -import json -import os -import sys -from collections import Counter -from datetime import datetime - -# To run this script directly, we need to add the src directory to the Python path -# to resolve the inconsistent imports in other modules. -src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -if src_path not in sys.path: - sys.path.insert(0, src_path) - -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService -from pydantic import BaseModel, Field - -#TODO: Fix this - -# Default values (previously from definitions.py) -DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123") - -# 定义用于LLM结构化输出的Pydantic模型 -class TagClassification(BaseModel): - """ - Represents the classification of a tag into a specific domain. - """ - - domain: str = Field( - ..., - description="The domain the tag belongs to, chosen from the predefined list.", - examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"], - ) - -class InsightReport(BaseModel): - """ - Represents the final insight report generated by the LLM. - """ - - report: str = Field( - ..., - description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.", - ) - - -class MemoryInsight: - """ - Provides insights into user memories by analyzing various aspects of their data. 
- """ - - def __init__(self, user_id: str): - self.user_id = user_id - self.neo4j_connector = Neo4jConnector() - - # Get config_id using get_end_user_connected_config - with get_db_context() as db: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - connected_config = get_end_user_connected_config(user_id, db) - config_id = connected_config.get("memory_config_id") - - if config_id: - # Use the config_id to get the proper LLM client - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config(config_id) - factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(memory_config.llm_model_id) - else: - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM if no config found - factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID) - except Exception as e: - print(f"Failed to get user connected config, using default LLM: {e}") - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM - factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID) - - async def close(self): - """关闭数据库连接。""" - await self.neo4j_connector.close() - - async def get_domain_distribution(self) -> dict[str, float]: - """ - Calculates the distribution of memory domains based on hot tags. 
- """ - hot_tags = await get_hot_memory_tags(self.user_id) - if not hot_tags: - return {} - - domain_counts = Counter() - for tag, _ in hot_tags: - prompt = f"""请将以下标签归类到最合适的领域中。 - -可选领域及其关键词: -- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等 -- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等 -- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等 -- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等 -- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等 -- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等 -- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等 -- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等 -- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等 -- 其他:确实无法归入以上任何类别的内容 - -标签: {tag} - -分析步骤: -1. 仔细理解标签的核心含义和使用场景 -2. 对比各个领域的关键词,找到最匹配的领域 -3. 特别注意: - - 历史、科学、文化等知识性内容应归类为"学习" - - 学校、课程、考试等正式教育场景应归类为"教育" - - 只有在标签完全不属于上述9个具体领域时,才选择"其他" -4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他" - -请直接返回最合适的领域名称。""" - messages = [ - {"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"}, - {"role": "user", "content": prompt} - ] - # 直接调用并等待结果 - classification = await self.llm_client.response_structured( - messages=messages, - response_model=TagClassification, - ) - if classification and hasattr(classification, 'domain') and classification.domain: - domain_counts[classification.domain] += 1 - - total_tags = sum(domain_counts.values()) - if total_tags == 0: - return {} - - domain_distribution = { - domain: count / total_tags for domain, count in domain_counts.items() - } - return dict( - sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True) - ) - - async def get_active_periods(self) -> list[int]: - """ - Identifies the top 2 most active months for the user. - Only returns months if there is valid and diverse time data. - - This method checks if the time data represents real user memory timestamps - rather than auto-generated system timestamps by verifying: - 1. Time data exists and is parseable - 2. 
Time data is distributed across multiple months (not concentrated in 1-2 months) - """ - query = f""" - MATCH (d:Dialogue) - WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> '' - RETURN d.created_at AS creation_time - """ - records = await self.neo4j_connector.execute_query(query) - - if not records: - return [] - - month_counts = Counter() - valid_dates_count = 0 - for record in records: - creation_time_str = record.get("creation_time") - if not creation_time_str: - continue - try: - # 尝试解析时间字符串 - dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00")) - month_counts[dt_object.month] += 1 - valid_dates_count += 1 - except (ValueError, TypeError, AttributeError): - # 如果解析失败,跳过这条记录 - continue - - # 如果没有有效的时间数据,返回空列表 - if not month_counts or valid_dates_count == 0: - return [] - - # 检查时间分布是否过于集中(可能是批量导入的数据) - # 如果超过80%的数据集中在1-2个月,认为这是系统时间戳而非真实时间 - unique_months = len(month_counts) - if unique_months <= 2: - # 只有1-2个月有数据,很可能是批量导入 - most_common_count = month_counts.most_common(1)[0][1] - if most_common_count / valid_dates_count > 0.8: - # 超过80%集中在一个月,认为是系统时间戳 - return [] - - # 如果时间分布较为分散(3个月以上),认为是真实时间数据 - if unique_months >= 3: - most_common_months = month_counts.most_common(2) - return [month for month, _ in most_common_months] - - # 2个月的情况,检查是否分布均匀 - if unique_months == 2: - counts = list(month_counts.values()) - # 如果两个月的数据量相差不大(比例在0.3-3之间),认为是真实数据 - ratio = min(counts) / max(counts) - if ratio > 0.3: - most_common_months = month_counts.most_common(2) - return [month for month, _ in most_common_months] - - # 其他情况返回空列表 - return [] - - async def get_social_connections(self) -> dict | None: - """ - Finds the user with whom the most memories are shared. 
- 使用 Chunk-Statement 的 CONTAINS 关系,因为系统中不创建 Dialogue-Statement 的 MENTIONS 关系。 - """ - # 通过 Chunk 和 Statement 的 CONTAINS 关系来查找共同记忆 - query = f""" - MATCH (c1:Chunk {{group_id: '{self.user_id}'}}) - OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement) - OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk) - WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL - WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements - WHERE common_statements > 0 - RETURN other_user_id, common_statements - ORDER BY common_statements DESC - LIMIT 1 - """ - records = await self.neo4j_connector.execute_query(query) - if not records or not records[0].get("other_user_id"): - return None - - most_connected_user = records[0]["other_user_id"] - common_memories_count = records[0]["common_statements"] - - # 使用 Chunk 的时间范围 - time_range_query = f""" - MATCH (c:Chunk) - WHERE c.group_id IN ['{self.user_id}', '{most_connected_user}'] - RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time - """ - time_records = await self.neo4j_connector.execute_query(time_range_query) - start_year, end_year = "N/A", "N/A" - if time_records and time_records[0]["start_time"]: - start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year - end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year - - return { - "user_id": most_connected_user, - "common_memories_count": common_memories_count, - "time_range": f"{start_year}-{end_year}", - } - - async def close(self): - """ - Closes the database connection. - """ - await self.neo4j_connector.close() - - -async def main(): - """ - Initializes and runs the memory insight analysis for a test user. 
- """ - # 默认从环境变量读取 - test_user_id = DEFAULT_GROUP_ID - print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n") - - try: - # 使用服务层函数生成报告 - from app.services.user_memory_service import analytics_memory_insight_report - - result = await analytics_memory_insight_report(end_user_id=test_user_id) - report = result.get("report", "") - - print("--- 记忆洞察报告 ---") - print(report) - print("---------------------") - - # 将结果写入统一的 User-Dashboard.json,使用全局配置路径 - try: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_dir = settings.MEMORY_OUTPUT_DIR - try: - os.makedirs(output_dir, exist_ok=True) - except Exception: - pass - dashboard_path = os.path.join(output_dir, "User-Dashboard.json") - existing = {} - if os.path.exists(dashboard_path): - with open(dashboard_path, "r", encoding="utf-8") as rf: - existing = json.load(rf) - existing["memory_insight"] = { - "group_id": test_user_id, - "report": report - } - with open(dashboard_path, "w", encoding="utf-8") as wf: - json.dump(existing, wf, ensure_ascii=False, indent=2) - print(f"已写入 {dashboard_path} -> memory_insight") - except Exception as e: - print(f"写入 User-Dashboard.json 失败: {e}") - except Exception as e: - print(f"生成报告时出错: {e}") - - -if __name__ == "__main__": - # This setup allows running the async main function - if sys.platform.startswith('win') and sys.version_info >= (3, 8): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - asyncio.run(main()) diff --git a/api/app/core/memory/analytics/user_summary.py b/api/app/core/memory/analytics/user_summary.py deleted file mode 100644 index f0283993..00000000 --- a/api/app/core/memory/analytics/user_summary.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Generate a concise "关于我" style user summary using data from Neo4j -and the existing LLM configuration (mirrors hot_memory_tags.py setup). 
- -Usage: - python -m analytics.user_summary --user_id -""" - -import asyncio -import json -import os -import sys -from dataclasses import dataclass -from typing import List, Tuple - -# Ensure absolute imports work whether executed directly or via module -try: - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) - src_path = os.path.join(project_root, 'src') - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -except Exception: - pass - -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService - -#TODO: Fix this - -# Default values (previously from definitions.py) -DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123") - - -@dataclass -class StatementRecord: - statement: str - created_at: str | None - - -class UserSummary: - """Builds a textual user summary for a given user/group id.""" - - def __init__(self, user_id: str): - self.user_id = user_id - self.connector = Neo4jConnector() - - # Get config_id using get_end_user_connected_config - with get_db_context() as db: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - connected_config = get_end_user_connected_config(user_id, db) - config_id = connected_config.get("memory_config_id") - - if config_id: - # Use the config_id to get the proper LLM client - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config(config_id) - factory = MemoryClientFactory(db) - self.llm = factory.get_llm_client(memory_config.llm_model_id) - else: - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config 
- # Fallback to default LLM if no config found - factory = MemoryClientFactory(db) - self.llm = factory.get_llm_client(DEFAULT_LLM_ID) - except Exception as e: - print(f"Failed to get user connected config, using default LLM: {e}") - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM - factory = MemoryClientFactory(db) - self.llm = factory.get_llm_client(DEFAULT_LLM_ID) - - async def close(self): - await self.connector.close() - - async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]: # TODO Used by user_memory_service - """Fetch recent statements authored by the user/group for context.""" - query = ( - "MATCH (s:Statement) " - "WHERE s.group_id = $group_id AND s.statement IS NOT NULL " - "RETURN s.statement AS statement, s.created_at AS created_at " - "ORDER BY created_at DESC LIMIT $limit" - ) - rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit) - records: List[StatementRecord] = [] - for r in rows: - try: - records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at"))) - except Exception: - continue - return records - - async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]: - """Reuse hot tag logic to get meaningful entities and their frequencies.""" - # get_hot_memory_tags internally filters out non-meaningful nouns with LLM - return await get_hot_memory_tags(self.user_id, limit=limit) # TODO Used by user_memory_service - - -async def generate_user_summary(user_id: str | None = None) -> str: # TODO useless - """ - 生成用户摘要的便捷函数 - - Args: - user_id: 可选的用户ID - - Returns: - 用户摘要字符串 - """ - # 导入服务层函数 - from app.services.user_memory_service import analytics_user_summary - - # 调用服务层函数 - result = await analytics_user_summary(user_id) - return result.get("summary", "") - - -if __name__ == "__main__": - print("开始生成用户摘要…") - try: - # 直接使用 runtime.json 中的 group_id - summary = 
asyncio.run(generate_user_summary()) - print("\n— 用户摘要 —\n") - print(summary) - - # 将结果写入统一的 User-Dashboard.json - try: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_dir = settings.MEMORY_OUTPUT_DIR - try: - os.makedirs(output_dir, exist_ok=True) - except Exception: - pass - dashboard_path = os.path.join(output_dir, "User-Dashboard.json") - existing = {} - if os.path.exists(dashboard_path): - with open(dashboard_path, "r", encoding="utf-8") as rf: - existing = json.load(rf) - existing["user_summary"] = { - "group_id": DEFAULT_GROUP_ID, - "summary": summary - } - with open(dashboard_path, "w", encoding="utf-8") as wf: - json.dump(existing, wf, ensure_ascii=False, indent=2) - print(f"已写入 {dashboard_path} -> user_summary") - except Exception as e: - print(f"写入 User-Dashboard.json 失败: {e}") - except Exception as e: - print(f"生成摘要失败: {e}") - print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 4d4221a3..39d618fc 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -37,12 +37,20 @@ def parse_historical_datetime(v): 此函数手动解析 ISO 8601 格式的日期字符串,支持1-4位年份 Args: - v: 日期值(可以是 None、datetime 对象或字符串) + v: 日期值(可以是 None、datetime 对象、Neo4j DateTime 对象或字符串) Returns: datetime 对象或 None """ - if v is None or isinstance(v, datetime): + if v is None: + return v + + # 处理 Neo4j DateTime 对象 + if hasattr(v, 'to_native'): + return v.to_native() + + # 处理 Python datetime 对象 + if isinstance(v, datetime): return v if isinstance(v, str): @@ -397,6 +405,10 @@ class ExtractedEntityNode(Node): statement_id: str = Field(..., description="Statement this entity was extracted from") entity_type: str = Field(..., description="Type of the entity") description: str = Field(..., description="Entity description") + example: str = Field( + default="", + description="A concise example 
(around 20 characters) to help understand the entity" + ) aliases: List[str] = Field( default_factory=list, description="Entity aliases - alternative names for this entity" @@ -433,6 +445,12 @@ class ExtractedEntityNode(Node): description="Total number of times this node has been accessed" ) + # Explicit Memory Classification + is_explicit_memory: bool = Field( + default=False, + description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" + ) + @field_validator('aliases', mode='before') @classmethod def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 @@ -466,6 +484,8 @@ class MemorySummaryNode(Node): dialog_id: ID of the parent dialog chunk_ids: List of chunk IDs used to generate this summary content: Summary text content + name: Title/name of the memory summary (generated by LLM, used as title in API) + memory_type: Type/category of the episodic memory (e.g., Conversation, Project/Work, Learning, Decision, Important Event) summary_embedding: Optional embedding vector for the summary metadata: Additional metadata for the summary config_id: Configuration ID used to process this summary @@ -484,6 +504,7 @@ class MemorySummaryNode(Node): dialog_id: str = Field(..., description="ID of the parent dialog") chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary") content: str = Field(..., description="Summary text content") + memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") diff --git a/api/app/core/memory/models/triplet_models.py b/api/app/core/memory/models/triplet_models.py index 
b0a062a3..df7ee14b 100644 --- a/api/app/core/memory/models/triplet_models.py +++ b/api/app/core/memory/models/triplet_models.py @@ -38,10 +38,20 @@ class Entity(BaseModel): name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name") type: str = Field(..., description="Type/category of the entity") description: str = Field(..., description="Description of the entity") + example: str = Field( + default="", + description="A concise example (around 20 characters) to help understand the entity" + ) aliases: List[str] = Field( default_factory=list, description="Alternative names for this entity (abbreviations, full names, translations, etc.)" ) + + # Explicit Memory Classification + is_explicit_memory: bool = Field( + default=False, + description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" + ) class Triplet(BaseModel): diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 7c2ed5f4..75aaa7df 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -42,7 +42,6 @@ from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_ ) from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( embedding_generation, - embedding_generation_all, generate_entity_embeddings_from_triplets, ) @@ -179,7 +178,7 @@ class ExtractionOrchestrator: for dialog in dialog_data_list: for chunk in dialog.chunks: all_statements_list.extend(chunk.statements) - total_statements = len(all_statements_list) + len(all_statements_list) # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") @@ -201,9 +200,9 @@ class ExtractionOrchestrator: 
all_entities_list.extend(triplet_info.entities) all_triplets_list.extend(triplet_info.triplets) - total_entities = len(all_entities_list) - total_triplets = len(all_triplets_list) - total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps) + len(all_entities_list) + len(all_triplets_list) + sum(len(temporal_map) for temporal_map in temporal_maps) # 步骤 3: 生成实体嵌入(依赖三元组提取结果) logger.info("步骤 3/6: 生成实体嵌入") @@ -385,7 +384,7 @@ class ExtractionOrchestrator: # 用于跟踪已完成的陈述句数量 completed_statements = 0 - total_statements = len(all_statements) + len(all_statements) # 全局并行处理所有陈述句 async def extract_for_statement(stmt_data, stmt_index): @@ -497,7 +496,7 @@ class ExtractionOrchestrator: # 用于跟踪已完成的时间提取数量 completed_temporal = 0 - total_temporal_statements = len(all_statements) + len(all_statements) # 全局并行处理所有陈述句 async def extract_for_statement(stmt_data, stmt_index): @@ -1082,10 +1081,12 @@ class ExtractionOrchestrator: statement_id=statement.id, # 添加必需的 statement_id 字段 entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type description=getattr(entity, 'description', ''), # 添加必需的 description 字段 + example=getattr(entity, 'example', ''), # 新增:传递示例字段 fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), + is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 group_id=dialog_data.group_id, user_id=dialog_data.user_id, apply_id=dialog_data.apply_id, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index 70c1ceb3..c72b9a1f 100644 --- 
a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -59,13 +59,28 @@ async def _process_chunk_summary( ) summary_text = structured.summary.strip() + # Generate title and type for the summary + title = None + episodic_type = None + try: + from app.services.user_memory_service import UserMemoryService + title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary( + content=summary_text, + end_user_id=dialog.group_id + ) + logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}") + except Exception as e: + logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}") + # Continue without title and type + # Embed the summary embedding = (await embedder.response([summary_text]))[0] # Build node per chunk + # Note: title is stored in the 'name' field, type is stored in 'memory_type' field node = MemorySummaryNode( id=uuid4().hex, - name=f"MemorySummaryChunk_{chunk.id}", + name=title if title else f"MemorySummaryChunk_{chunk.id}", group_id=dialog.group_id, user_id=dialog.user_id, apply_id=dialog.apply_id, @@ -75,6 +90,7 @@ async def _process_chunk_summary( dialog_id=dialog.id, chunk_ids=[chunk.id], content=summary_text, + memory_type=episodic_type, summary_embedding=embedding, metadata={"ref_id": dialog.ref_id}, config_id=dialog.config_id, # 添加 config_id diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index acc2a717..729a5542 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -246,7 +246,7 @@ class AccessHistoryManager: if not node_data: return ConsistencyCheckResult.CONSISTENT, None - 
access_history = node_data.get('access_history', []) + access_history = node_data.get('access_history') or [] last_access_time = node_data.get('last_access_time') access_count = node_data.get('access_count', 0) activation_value = node_data.get('activation_value') @@ -409,7 +409,7 @@ class AccessHistoryManager: logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") return False - access_history = node_data.get('access_history', []) + access_history = node_data.get('access_history') or [] importance_score = node_data.get('importance_score', 0.5) # 准备修复数据 @@ -530,7 +530,7 @@ class AccessHistoryManager: Returns: Dict[str, Any]: 更新数据,包含所有需要更新的字段 """ - access_history = node_data.get('access_history', []) + access_history = node_data.get('access_history') or [] importance_score = node_data.get('importance_score', 0.5) # 追加新的访问时间 diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py index 5e1e35da..f1802166 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py @@ -247,6 +247,9 @@ class ForgettingStrategy: entity_activation = entity_node['entity_activation'] entity_importance = entity_node['entity_importance'] + # 获取 group_id(从 statement 或 entity 节点) + group_id = statement_node.get('group_id') or entity_node.get('group_id') + # 生成摘要内容 summary_text = await self._generate_summary( statement_text=statement_text, @@ -256,6 +259,19 @@ class ForgettingStrategy: db=db ) + # 生成标题和类型(使用LLM) + from app.services.user_memory_service import UserMemoryService + try: + title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary( + content=summary_text, + end_user_id=group_id + ) + logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}") + except Exception as e: + logger.error(f"生成标题和类型失败,使用默认值: {str(e)}") + title = 
"未命名" + episodic_type = "其他" + # 计算继承的激活值和重要性(取较高值) inherited_activation = max(statement_activation, entity_activation) inherited_importance = max(statement_importance, entity_importance) @@ -268,9 +284,6 @@ class ForgettingStrategy: import uuid summary_id = f"summary_{uuid.uuid4().hex[:16]}" - # 获取 group_id(从 statement 或 entity 节点) - group_id = statement_node.get('group_id') or entity_node.get('group_id') - # 使用事务创建 MemorySummary 并删除原节点 async def merge_transaction(tx, **params): """事务函数:创建摘要节点并删除原节点""" @@ -287,6 +300,8 @@ class ForgettingStrategy: CREATE (ms:MemorySummary { id: $summary_id, summary: $summary_text, + name: $title, + memory_type: $episodic_type, original_statement_id: $statement_id, original_entity_id: $entity_id, activation_value: $inherited_activation, @@ -386,6 +401,8 @@ class ForgettingStrategy: params = { 'summary_id': summary_id, 'summary_text': summary_text, + 'title': title, + 'episodic_type': episodic_type, 'statement_id': statement_id, 'entity_id': entity_id, 'inherited_activation': inherited_activation, diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 842f3c82..50593e49 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -386,3 +386,26 @@ async def render_memory_insight_prompt( }) return rendered_prompt + + +async def render_episodic_title_and_type_prompt(content: str) -> str: + """ + Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template. 
+ + Args: + content: The content of the episodic memory summary to analyze + + Returns: + Rendered prompt content as string + """ + template = prompt_env.get_template("episodic_type_classification.jinja2") + rendered_prompt = template.render(content=content) + + # 记录渲染结果到提示日志 + log_prompt_rendering('episodic title and type classification', rendered_prompt) + # 可选:记录模板渲染信息 + log_template_rendering('episodic_type_classification.jinja2', { + 'content_len': len(content) if content else 0 + }) + + return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/episodic_type_classification.jinja2 b/api/app/core/memory/utils/prompt/prompts/episodic_type_classification.jinja2 new file mode 100644 index 00000000..fa382ec7 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/episodic_type_classification.jinja2 @@ -0,0 +1,57 @@ +=== Task === +Generate a concise title and classify the episodic memory into the most appropriate category. + +=== Requirements === +- Extract a clear, concise title (10-20 characters) that captures the core content +- Classify into exactly one category based on the primary theme +- Be specific and avoid ambiguity +- Output must be valid JSON conforming to the schema below + +=== Input === +{{ content }} + +=== Category Definitions === + +1. **conversation**: Daily communication, chat, discussion, and social interactions + - Keywords: chat, communication, discussion, dialogue, exchange + +2. **project_work**: Work-related tasks, projects, meetings, and collaboration + - Keywords: project, task, work, meeting, collaboration, business, client + +3. **learning**: Acquiring new knowledge, skill development, reading, and research + - Keywords: learning, reading, research, knowledge, skill, course, training + +4. **decision**: Making important decisions, choices, and planning + - Keywords: decision, choice, planning, consideration, evaluation, weighing + +5. 
**important_event**: Major events, milestones, and special experiences + - Keywords: important, major, milestone, special, memorable, celebration + +=== Analysis Steps === +1. Read the episodic memory content carefully +2. Identify the core theme and context +3. Extract a concise title +4. Compare against category definitions and keywords +5. Select the best matching category +6. If multiple categories apply, choose the primary one + +=== Output Schema === +**CRITICAL JSON FORMATTING REQUIREMENTS:** +1. Use only standard ASCII double quotes (") for JSON structure +2. Escape any quotation marks within string values using backslashes (\") +3. Ensure all JSON strings are properly closed and comma-separated +4. Do not include line breaks within JSON string values + +Return only a JSON object with title and type fields: +{ + "title": "Generated title here", + "type": "Category type here" +} + +The type field must be exactly one of: +- conversation +- project_work +- learning +- decision +- important_event + diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index 337b5d4f..03691a04 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -12,7 +12,34 @@ Extract entities and knowledge triplets from the given statement. 
===Guidelines=== **Entity Extraction:** -- Extract entities with their types, context-independent descriptions, and aliases +- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification +- **Semantic Memory Classification (is_explicit_memory):** + * Set to `true` if the entity represents **explicit/semantic memory**: + - **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主" + - **Knowledge:** "Python Programming Language", "Theory of Relativity", "Python编程语言", "相对论" + - **Definitions:** "API (Application Programming Interface)", "REST API", "应用程序接口" + - **Principles:** "SOLID Principles", "First Law of Thermodynamics", "SOLID原则", "热力学第一定律" + - **Theories:** "Evolution Theory", "Quantum Mechanics", "进化论", "量子力学" + - **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm", "敏捷开发", "机器学习算法" + - **Technical Terms:** "Neural Network", "Database", "神经网络", "数据库" + * Set to `false` for: + - **People:** "John Smith", "Dr. 
Wang", "张明", "王博士" + - **Organizations:** "Microsoft", "Harvard University", "微软", "哈佛大学" + - **Locations:** "Beijing", "Central Park", "北京", "中央公园" + - **Events:** "2024 Conference", "Project Meeting", "2024会议", "项目会议" + - **Specific objects:** "iPhone 15", "Building A", "iPhone 15", "A栋" +- **Example Generation (IMPORTANT for semantic memory entities):** + * For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept + * The example should be: + - **Specific and concrete**: Use real-world scenarios or applications + - **Brief**: Around 20 characters (can be slightly longer if needed for clarity) + - **In the same language as the entity name** + * Examples: + - Entity: "机器学习" → example: "如:用神经网络识别图片中的猫狗" + - Entity: "SOLID Principles" → example: "e.g., Single Responsibility, Open-Closed" + - Entity: "Photosynthesis" → example: "e.g., plants convert sunlight to energy" + - Entity: "人工智能" → example: "如:智能客服、自动驾驶" + * For non-semantic entities (`is_explicit_memory=false`), the example field can be empty - **Aliases Extraction (Important):** * **CRITICAL: Extract aliases ONLY in the SAME LANGUAGE as the input text** * **DO NOT translate or add aliases in different languages** @@ -84,21 +111,27 @@ Output: "name": "I", "type": "Person", "description": "The user", - "aliases": [] + "example": "", + "aliases": [], + "is_explicit_memory": false }, { "entity_idx": 1, "name": "Paris", "type": "Location", "description": "Capital city of France", - "aliases": [] + "example": "", + "aliases": [], + "is_explicit_memory": false }, { "entity_idx": 2, "name": "Louvre", "type": "Location", "description": "World-famous museum located in Paris", - "aliases": ["Louvre Museum"] + "example": "", + "aliases": ["Louvre Museum"], + "is_explicit_memory": false } ] } @@ -130,21 +163,27 @@ Output: "name": "John Smith", "type": "Person", "description": "Individual person name", - "aliases": [] + "example": "", + "aliases": 
[], + "is_explicit_memory": false }, { "entity_idx": 1, "name": "Google", "type": "Organization", "description": "American technology company", - "aliases": ["Google LLC", "Alphabet Inc."] + "example": "", + "aliases": ["Google LLC", "Alphabet Inc."], + "is_explicit_memory": false }, { "entity_idx": 2, "name": "AI product development", - "type": "WorkRole", + "type": "Concept", "description": "Artificial intelligence product development work", - "aliases": [] + "example": "e.g., developing chatbots, recommendation systems", + "aliases": [], + "is_explicit_memory": true } ] } @@ -176,21 +215,27 @@ Output: "name": "我", "type": "Person", "description": "用户本人", - "aliases": [] + "example": "", + "aliases": [], + "is_explicit_memory": false }, { "entity_idx": 1, "name": "巴黎", "type": "Location", "description": "法国首都城市", - "aliases": [] + "example": "", + "aliases": [], + "is_explicit_memory": false }, { "entity_idx": 2, "name": "卢浮宫", "type": "Location", "description": "位于巴黎的世界著名博物馆", - "aliases": [] + "example": "", + "aliases": [], + "is_explicit_memory": false } ] } @@ -222,21 +267,27 @@ Output: "name": "张明", "type": "Person", "description": "个人姓名", - "aliases": [] + "example": "", + "aliases": [], + "is_explicit_memory": false }, { "entity_idx": 1, "name": "腾讯", "type": "Organization", "description": "中国科技公司", - "aliases": ["腾讯控股", "腾讯公司"] + "example": "", + "aliases": ["腾讯控股", "腾讯公司"], + "is_explicit_memory": false }, { "entity_idx": 2, "name": "AI产品开发", - "type": "WorkRole", + "type": "Concept", "description": "人工智能产品研发工作", - "aliases": [] + "example": "如:开发智能客服机器人、推荐系统", + "aliases": [], + "is_explicit_memory": true } ] } @@ -251,7 +302,9 @@ Output: "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", - "aliases": ["Camera Tripod"] + "example": "", + "aliases": ["Camera Tripod"], + "is_explicit_memory": false } ] } @@ -266,7 +319,9 @@ Output: "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", - "aliases": ["相机三脚架"] + 
"example": "", + "aliases": ["相机三脚架"], + "is_explicit_memory": false } ] } diff --git a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 index 373ab31e..2f452c53 100644 --- a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 @@ -85,33 +85,21 @@ Example Output: ===End of Example=== -===Reflection Process=== +===Internal Quality Checks (DO NOT OUTPUT)=== -After generating the profile, perform the following self-review steps: +Before generating your final output, internally verify: +1. All content is grounded in provided data (no fabrication) +2. Format follows the specified structure with correct headers +3. Tone is objective, third-person, and neutral +4. All four sections are complete and within character limits -**Step 1: Data Grounding Check** -- Verify all statements are supported by the provided entities and statements -- Ensure no fabricated or speculated information is included -- Confirm all claims can be traced back to the input data - -**Step 2: Format Compliance** -- Verify each section follows the specified format with section headers -- Check character count limits for each section -- Ensure proper use of section markers (【】) - -**Step 3: Tone and Style Review** -- Confirm objective third-person perspective is maintained -- Check for excessive adjectives or empty phrases -- Verify neutral and restrained tone throughout - -**Step 4: Completeness Check** -- Ensure all four sections are present and complete -- Verify each section addresses its specific focus area -- Confirm the one-sentence summary effectively captures the user's essence +**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.** ===Output Requirements=== +**CRITICAL: Your response must ONLY contain the four sections below. 
Do not include any reflection, self-review, or meta-commentary.** + **LANGUAGE REQUIREMENT:** - The output language should ALWAYS be Chinese (Simplified) - All section content must be in Chinese @@ -122,3 +110,5 @@ After generating the profile, perform the following self-review steps: - Content follows immediately after the header - Sections are separated by blank lines - Strictly adhere to character limits for each section +- **DO NOT include any text after the 【一句话总结】 section** +- **DO NOT output reflection steps, self-review, or verification notes** diff --git a/api/app/core/rag/app/naive.py b/api/app/core/rag/app/naive.py index 23f0c4ba..2b8d0e50 100644 --- a/api/app/core/rag/app/naive.py +++ b/api/app/core/rag/app/naive.py @@ -64,8 +64,8 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese" def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None, **kwargs): textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown") - app_id = os.environ.get("TEXTLN_APP_ID", "fa3f24380683ad53e6c620c0f0878a09") - secret_code = os.environ.get("TEXTLN_SECRET_CODE", "6130caac9aabc6eb26433758d7898f4a") + app_id = os.environ.get("TEXTLN_APP_ID", "") + secret_code = os.environ.get("TEXTLN_SECRET_CODE", "") pdf_parser = TextLnParser(textln_api=textln_api, app_id=app_id, secret_code=secret_code) sections, tables = pdf_parser.parse_pdf( diff --git a/api/app/core/rag/llm/cv_model.py b/api/app/core/rag/llm/cv_model.py index 5f841433..24d4a35b 100644 --- a/api/app/core/rag/llm/cv_model.py +++ b/api/app/core/rag/llm/cv_model.py @@ -448,7 +448,7 @@ if __name__ == "__main__": # 准备配置vision_model信息 # 初始化 QWenCV vision_model = QWenCV( - key="sk-8e9e40cd171749858ce2d3722ea75669", + key="", model_name="qwen-vl-max", lang="Chinese", # 默认使用中文 base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" diff --git a/api/app/core/tools/base.py 
b/api/app/core/tools/base.py index ec15c50f..2cdc0f60 100644 --- a/api/app/core/tools/base.py +++ b/api/app/core/tools/base.py @@ -191,10 +191,14 @@ class BaseTool(ABC): execution_time=execution_time ) - def to_langchain_tool(self): - """转换为Langchain工具格式""" + def to_langchain_tool(self, operation: Optional[str] = None): + """转换为Langchain工具格式 + + Args: + operation: 特定操作(适用于有操作的工具) + """ from app.core.tools.langchain_adapter import LangchainAdapter - return LangchainAdapter.convert_tool(self) + return LangchainAdapter.convert_tool(self, operation) def __repr__(self): return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>" \ No newline at end of file diff --git a/api/app/core/tools/builtin/operation_tool.py b/api/app/core/tools/builtin/operation_tool.py new file mode 100644 index 00000000..126541a8 --- /dev/null +++ b/api/app/core/tools/builtin/operation_tool.py @@ -0,0 +1,216 @@ +"""操作工具 - 为特定操作创建的工具包装器""" +from typing import List +from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType +from app.models import ToolType + + +class OperationTool(BaseTool): + """操作工具 - 包装基础工具的特定操作""" + + def __init__(self, base_tool: BaseTool, operation: str): + self.base_tool = base_tool + self.operation = operation + super().__init__(base_tool.tool_id, base_tool.config) + + @property + def name(self) -> str: + return f"{self.base_tool.name}_{self.operation}" + + @property + def tool_type(self) -> ToolType: + """工具类型""" + return ToolType.BUILTIN + + @property + def description(self) -> str: + return f"{self.base_tool.description} - {self.operation}" + + @property + def parameters(self) -> List[ToolParameter]: + """返回特定操作的参数""" + if self.base_tool.name == 'datetime_tool': + return self._get_datetime_params() + elif self.base_tool.name == 'json_tool': + return self._get_json_params() + else: + # 默认返回除operation外的所有参数 + return [p for p in self.base_tool.parameters if p.name != "operation"] + + def _get_datetime_params(self) -> 
List[ToolParameter]: + """获取datetime_tool特定操作的参数""" + if self.operation == "now": + return [ + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ) + ] + elif self.operation == "format": + return [ + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=True + ), + ToolParameter( + name="input_format", + type=ParameterType.STRING, + description="输入时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ) + ] + elif self.operation == "convert_timezone": + return [ + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=True + ), + ToolParameter( + name="input_format", + type=ParameterType.STRING, + description="输入时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="from_timezone", + type=ParameterType.STRING, + description="源时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ), + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ) + ] + elif self.operation == "timestamp_to_datetime": + return [ + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=True + ), + ToolParameter( + name="output_format", + 
type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="Asia/Shanghai" + ) + ] + else: + return [] + + def _get_json_params(self) -> List[ToolParameter]: + """获取json_tool特定操作的参数""" + base_params = [ + ToolParameter( + name="input_data", + type=ParameterType.STRING, + description="输入数据(JSON字符串、YAML字符串或XML字符串)", + required=True + ) + ] + + if self.operation == "insert": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ), + ToolParameter( + name="new_value", + type=ParameterType.STRING, + description="新值(用于insert操作)", + required=True + ) + ] + elif self.operation == "replace": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ), + ToolParameter( + name="old_text", + type=ParameterType.STRING, + description="要替换的原文本(用于replace操作)", + required=True + ), + ToolParameter( + name="new_text", + type=ParameterType.STRING, + description="替换后的新文本(用于replace操作)", + required=True + ) + ] + elif self.operation == "delete": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ) + ] + elif self.operation == "parse": + return base_params + [ + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(如:$.user.name或users[0].name)", + required=True + ) + ] + else: + return base_params + + async def execute(self, **kwargs) -> ToolResult: + """执行特定操作""" + # 添加operation参数 + kwargs["operation"] = self.operation + return await self.base_tool.execute(**kwargs) \ No newline at end of file diff --git 
a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py index 0d656a8e..3dfe4c93 100644 --- a/api/app/core/tools/custom/base.py +++ b/api/app/core/tools/custom/base.py @@ -1,4 +1,5 @@ """自定义工具基类""" +import json import time from typing import Dict, Any, List, Optional import aiohttp @@ -135,6 +136,13 @@ class CustomTool(BaseTool): if not self.schema_content: return operations + + if isinstance(self.schema_content, str): + try: + self.schema_content = json.loads(self.schema_content) + except json.JSONDecodeError: + logger.error(f"无效的OpenAPI schema: {self.schema_content}") + return operations paths = self.schema_content.get("paths", {}) diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index 1b6969b9..ea5fdb96 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -21,24 +21,35 @@ class LangchainToolWrapper(LangchainBaseTool): # 内部工具实例 tool_instance: BaseTool = Field(..., description="内部工具实例") + # 特定操作(用于自定义工具) + operation: Optional[str] = Field(None, description="特定操作") class Config: arbitrary_types_allowed = True - def __init__(self, tool_instance: BaseTool, **kwargs): + def __init__(self, tool_instance: BaseTool, operation: Optional[str] = None, **kwargs): """初始化Langchain工具包装器 Args: tool_instance: 内部工具实例 + operation: 特定操作(用于自定义工具) """ # 动态创建参数schema - args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters) + args_schema = LangchainAdapter._create_pydantic_schema( + tool_instance.parameters, operation + ) + + # 构建工具名称 + tool_name = tool_instance.name + if operation: + tool_name = f"{tool_instance.name}_{operation}" super().__init__( - name=tool_instance.name, + name=tool_name, description=tool_instance.description, args_schema=args_schema, - _tool_instance=tool_instance, + tool_instance=tool_instance, + operation=operation, **kwargs ) @@ -58,8 +69,12 @@ class LangchainToolWrapper(LangchainBaseTool): ) -> str: """异步执行工具""" try: + # 
如果有特定操作,添加到参数中 + if self.operation: + kwargs["operation"] = self.operation + # 执行内部工具 - result = await self._tool_instance.safe_execute(**kwargs) + result = await self.tool_instance.safe_execute(**kwargs) # 转换结果为Langchain格式 return LangchainAdapter._format_result_for_langchain(result) @@ -73,24 +88,82 @@ class LangchainAdapter: """Langchain适配器 - 负责工具格式转换和标准化""" @staticmethod - def convert_tool(tool: BaseTool) -> LangchainToolWrapper: + def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper: """将内部工具转换为Langchain工具 Args: tool: 内部工具实例 + operation: 特定操作(适用于有操作的工具)或MCP工具名称 Returns: Langchain兼容的工具包装器 """ try: - wrapper = LangchainToolWrapper(tool_instance=tool) - logger.debug(f"工具转换成功: {tool.name} -> Langchain格式") - return wrapper + # 处理MCP工具的特定工具名称 + if hasattr(tool, 'tool_type') and tool.tool_type.value == "mcp" and operation: + # 为MCP工具创建特定工具名称的实例 + mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation) + wrapper = LangchainToolWrapper(tool_instance=mcp_tool) + logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式") + return wrapper + elif operation and LangchainAdapter._tool_supports_operations(tool): + # 为支持多操作的工具创建特定操作实例 + if tool.tool_type.value == "custom": + # 自定义工具直接传递operation参数 + wrapper = LangchainToolWrapper(tool_instance=tool, operation=operation) + else: + # 内置工具使用OperationTool包装 + operation_tool = LangchainAdapter._create_operation_tool(tool, operation) + wrapper = LangchainToolWrapper(tool_instance=operation_tool) + logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式") + return wrapper + else: + # 单个工具 + wrapper = LangchainToolWrapper(tool_instance=tool) + logger.debug(f"工具转换成功: {tool.name} -> Langchain格式") + return wrapper except Exception as e: logger.error(f"工具转换失败: {tool.name}, 错误: {e}") raise + @staticmethod + def _tool_supports_operations(tool: BaseTool) -> bool: + """检查工具是否支持多操作""" + # 内置工具中支持操作的工具 + builtin_operation_tools = ['datetime_tool', 'json_tool'] + + # 检查内置工具 + if 
tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools: + return True + + # 检查自定义工具(自定义工具通过解析OpenAPI schema支持多操作) + if tool.tool_type.value == "custom": + # 检查工具是否有多个操作 + if hasattr(tool, '_parsed_operations') and len(tool._parsed_operations) > 1: + return True + # 或者检查参数中是否有operation参数 + for param in tool.parameters: + if param.name == "operation" and param.enum: + return True + + return False + + @staticmethod + def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool: + """为特定操作创建工具实例""" + if base_tool.tool_type.value == "builtin": + from app.core.tools.builtin.operation_tool import OperationTool + return OperationTool(base_tool, operation) + else: + raise ValueError(f"不支持的工具类型: {base_tool.tool_type.value}") + + @staticmethod + def _create_mcp_tool_with_name(mcp_tool: BaseTool, tool_name: str) -> BaseTool: + """为MCP工具创建指定工具名称的实例""" + mcp_tool.set_current_tool(tool_name) + return mcp_tool + @staticmethod def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]: """批量转换工具 @@ -110,15 +183,19 @@ class LangchainAdapter: except Exception as e: logger.error(f"跳过工具转换: {tool.name}, 错误: {e}") - logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具") + logger.info(f"批量转换完成: {len(converted_tools)} 个工具") return converted_tools @staticmethod - def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]: + def _create_pydantic_schema( + parameters: List[ToolParameter], + operation: Optional[str] = None + ) -> Type[BaseModel]: """根据工具参数创建Pydantic schema Args: parameters: 工具参数列表 + operation: 特定操作(用于过滤参数) Returns: Pydantic模型类 @@ -127,7 +204,12 @@ class LangchainAdapter: fields = {} annotations = {} - for param in parameters: + # 如果指定了operation,过滤掉operation参数 + filtered_params = parameters + if operation: + filtered_params = [p for p in parameters if p.name != "operation"] + + for param in filtered_params: # 确定Python类型 python_type = LangchainAdapter._get_python_type(param.type) @@ -169,9 +251,10 @@ 
class LangchainAdapter: "ToolArgsSchema", (BaseModel,), { + "__module__": __name__, "__annotations__": annotations, - **fields, - "Config": type("Config", (), {"extra": "forbid"}) + "model_config": {"extra": "forbid"}, + **fields } ) diff --git a/api/app/core/tools/mcp/__init__.py b/api/app/core/tools/mcp/__init__.py index 4c9519b3..b48aa096 100644 --- a/api/app/core/tools/mcp/__init__.py +++ b/api/app/core/tools/mcp/__init__.py @@ -1,12 +1,20 @@ -"""MCP工具模块""" +"""MCP 工具模块 - Model Context Protocol 支持""" -from app.core.tools.mcp.base import MCPTool -from app.core.tools.mcp.client import MCPClient, MCPConnectionPool -from app.core.tools.mcp.service_manager import MCPServiceManager +# 主要类导出 +from .base import MCPTool, MCPToolManager, MCPError +from .client import SimpleMCPClient, MCPConnectionError +from .service_manager import MCPServiceManager __all__ = [ + # 核心类 "MCPTool", - "MCPClient", - "MCPConnectionPool", + "MCPToolManager", + "MCPError", + + # 客户端类 + "SimpleMCPClient", + "MCPConnectionError", + + # 服务管理(简化版) "MCPServiceManager" ] \ No newline at end of file diff --git a/api/app/core/tools/mcp/base.py b/api/app/core/tools/mcp/base.py index 3fa103ab..9e683ead 100644 --- a/api/app/core/tools/mcp/base.py +++ b/api/app/core/tools/mcp/base.py @@ -1,10 +1,9 @@ -"""MCP工具基类""" +"""MCP工具基类 - 整合版本""" import time -from typing import Dict, Any, List +from typing import List, Dict, Any from app.models.tool_model import ToolType -from app.core.tools.base import BaseTool -from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType +from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType from app.core.logging_config import get_business_logger logger = get_business_logger() @@ -14,215 +13,188 @@ class MCPTool(BaseTool): """MCP工具 - Model Context Protocol工具""" def __init__(self, tool_id: str, config: Dict[str, Any]): - """初始化MCP工具 - - Args: - tool_id: 工具ID - config: 工具配置 - """ super().__init__(tool_id, config) self.server_url = 
config.get("server_url", "") self.connection_config = config.get("connection_config", {}) self.available_tools = config.get("available_tools", []) - self._client = None - self._connected = False @property def name(self) -> str: - """工具名称""" return f"mcp_tool_{self.tool_id[:8]}" @property def description(self) -> str: - """工具描述""" return f"MCP工具 - 连接到 {self.server_url}" @property def tool_type(self) -> ToolType: - """工具类型""" return ToolType.MCP @property def parameters(self) -> List[ToolParameter]: - """工具参数定义""" - params = [] + """根据工具名称返回对应参数""" + # 如果有指定的工具名称,从 available_tools 中获取参数 + tool_name = getattr(self, '_current_tool_name', None) + if tool_name and self.available_tools: + for tool_info in self.available_tools: + if tool_info.get("tool_name") == tool_name: + arguments = tool_info.get("arguments", {}) + return self._generate_parameters_from_schema(arguments) - # 添加工具选择参数 - if len(self.available_tools) > 1: - params.append(ToolParameter( + # 默认返回通用参数 + return [ + ToolParameter( name="tool_name", type=ParameterType.STRING, - description="要调用的MCP工具名称", - required=True, - enum=self.available_tools - )) - - # 添加通用参数 - params.extend([ + description="要执行的工具名称", + required=True + ), ToolParameter( name="arguments", type=ParameterType.OBJECT, - description="工具参数(JSON对象)", + description="工具参数", required=False, default={} - ), - ToolParameter( - name="timeout", - type=ParameterType.INTEGER, - description="超时时间(秒)", - required=False, - default=30, - minimum=1, - maximum=300 ) - ]) + ] + + def _generate_parameters_from_schema(self, arguments: Dict[str, Any]) -> List[ToolParameter]: + """从参数schema生成参数列表""" + properties = arguments.get("properties", {}) + required_fields = arguments.get("required", []) + + params = [] + for param_name, param_def in properties.items(): + param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string")) + + params.append(ToolParameter( + name=param_name, + type=param_type, + description=param_def.get("description", 
f"参数: {param_name}"), + required=param_name in required_fields, + default=param_def.get("default"), + enum=param_def.get("enum"), + minimum=param_def.get("minimum"), + maximum=param_def.get("maximum") + )) return params + def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType: + """转换JSON Schema类型到ParameterType""" + type_mapping = { + "string": ParameterType.STRING, + "integer": ParameterType.INTEGER, + "number": ParameterType.NUMBER, + "boolean": ParameterType.BOOLEAN, + "array": ParameterType.ARRAY, + "object": ParameterType.OBJECT + } + return type_mapping.get(json_type, ParameterType.STRING) + + def set_current_tool(self, tool_name: str): + """设置当前工具名称,用于获取特定参数""" + self._current_tool_name = tool_name + async def execute(self, **kwargs) -> ToolResult: """执行MCP工具""" start_time = time.time() try: - # 确保连接 - if not self._connected: - await self.connect() - - # 确定要调用的工具 tool_name = kwargs.get("tool_name") - if not tool_name and len(self.available_tools) == 1: - tool_name = self.available_tools[0] - if not tool_name: - raise ValueError("必须指定要调用的MCP工具名称") + raise Exception("未指定工具名称") - if tool_name not in self.available_tools: - raise ValueError(f"MCP工具不存在: {tool_name}") - - # 获取参数 arguments = kwargs.get("arguments", {}) - timeout = kwargs.get("timeout", 30) - # 调用MCP工具 - result = await self._call_mcp_tool(tool_name, arguments, timeout) + from .client import SimpleMCPClient - execution_time = time.time() - start_time - return ToolResult.success_result( - data=result, - execution_time=execution_time - ) + client = SimpleMCPClient(self.server_url, self.connection_config) + async with client: + result = await client.call_tool(tool_name, arguments) + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + except Exception as e: execution_time = time.time() - start_time + logger.error(f"MCP工具执行失败: {kwargs.get('tool_name', 'unknown')}, 错误: {e}") return 
ToolResult.error_result( error=str(e), - error_code="MCP_ERROR", + error_code="MCP_EXECUTION_ERROR", execution_time=execution_time ) + + +class MCPError(Exception): + """MCP 错误基类""" + pass + + +class MCPToolManager: + """MCP 工具管理器 - 简化版本""" - async def connect(self) -> bool: - """连接到MCP服务器""" + def __init__(self, db=None): + self.db = db + self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info + + async def discover_tools( + self, + server_url: str, + connection_config: Dict[str, Any] = None + ) -> tuple[bool, List[Dict[str, Any]], str | None]: + """发现 MCP 服务器上的工具""" try: - from .client import MCPClient + from .client import SimpleMCPClient - if self._connected: - return True - - self._client = MCPClient(self.server_url, self.connection_config) - - if await self._client.connect(): - self._connected = True - # 更新可用工具列表 - await self._update_available_tools() - logger.info(f"MCP服务器连接成功: {self.server_url}") - return True - else: - logger.error(f"MCP服务器连接失败: {self.server_url}") - return False + client = SimpleMCPClient(server_url, connection_config) + async with client: + tools = await client.list_tools() + + # 缓存工具信息 + self._tool_cache[server_url] = { + "tools": tools, + "connection_config": connection_config, + "last_updated": time.time() + } + + logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}") + return True, tools, None + except Exception as e: - logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}") - self._connected = False - return False + error_msg = f"发现工具失败: {e}" + logger.error(error_msg) + return False, [], error_msg - async def _update_available_tools(self): - """更新可用工具列表""" + async def test_tool_connection( + self, + server_url: str, + connection_config: Dict[str, Any] = None + ) -> Dict[str, Any]: + """测试工具连接""" try: - if self._client and self._connected: - tools = await self._client.list_tools() - self.available_tools = [tool.get("name") for tool in tools if tool.get("name")] - logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 
个工具") - except Exception as e: - logger.error(f"更新MCP工具列表失败: {e}") - - async def disconnect(self) -> bool: - """断开MCP服务器连接""" - try: - if self._client: - await self._client.disconnect() - self._client = None + from .client import SimpleMCPClient - self._connected = False - logger.info(f"MCP服务器连接已断开: {self.server_url}") - return True - - except Exception as e: - logger.error(f"断开MCP服务器连接失败: {e}") - return False - - def get_health_status(self) -> Dict[str, Any]: - """获取MCP服务健康状态""" - return { - "connected": self._connected, - "server_url": self.server_url, - "available_tools": self.available_tools, - "last_check": time.time() - } - - async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any: - """调用MCP工具""" - if not self._client or not self._connected: - raise Exception("MCP客户端未连接") - - try: - result = await self._client.call_tool(tool_name, arguments, timeout) - return result - except Exception as e: - logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}") - raise - - async def list_available_tools(self) -> List[Dict[str, Any]]: - """列出可用的MCP工具""" - try: - if not self._connected: - await self.connect() - - if self._client: - tools = await self._client.list_tools() - self.available_tools = [tool.get("name") for tool in tools if tool.get("name")] - return tools - - return [] - - except Exception as e: - logger.error(f"获取MCP工具列表失败: {e}") - return [] - - def test_connection(self) -> Dict[str, Any]: - """测试MCP连接""" - try: - # 这里应该实现同步的连接测试 - # 为了简化,返回基本信息 - return { - "success": bool(self.server_url), - "server_url": self.server_url, - "connected": self._connected, - "available_tools_count": len(self.available_tools), - "message": "MCP配置有效" if self.server_url else "缺少服务器URL配置" - } + client = SimpleMCPClient(server_url, connection_config) + async with client: + tools = await client.list_tools() + + return { + "success": True, + "tools_count": len(tools), + "tools": [tool.get("name") for tool in tools], + "message": "连接成功" + } + except 
Exception as e: return { "success": False, - "error": str(e) + "error": str(e), + "message": "连接失败" } \ No newline at end of file diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index a1d2ecaa..2901b7ca 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -1,9 +1,8 @@ -"""MCP客户端 - Model Context Protocol客户端实现""" +"""MCP客户端 - 简化版本""" import asyncio import json import time -from typing import Dict, Any, List, Optional, Callable -from urllib.parse import urlparse +from typing import Dict, Any, List import aiohttp import websockets from websockets.exceptions import ConnectionClosed @@ -18,139 +17,156 @@ class MCPConnectionError(Exception): pass -class MCPProtocolError(Exception): - """MCP协议错误""" - pass - - -class MCPClient: - """MCP客户端 - 支持HTTP和WebSocket连接""" +class SimpleMCPClient: + """简化的 MCP 客户端""" def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): - """初始化MCP客户端 - - Args: - server_url: MCP服务器URL - connection_config: 连接配置 - """ self.server_url = server_url self.connection_config = connection_config or {} + self.timeout = self.connection_config.get("timeout", 30) - # 解析URL确定连接类型 - parsed_url = urlparse(server_url) - self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http" + # 确定连接类型 + self.is_websocket = server_url.startswith(("ws://", "wss://")) # 连接状态 - self._connected = False self._websocket = None self._session = None - - # 请求管理 self._request_id = 0 - self._pending_requests: Dict[str, asyncio.Future] = {} - - # 连接池配置 - self.max_connections = self.connection_config.get("max_connections", 10) - self.connection_timeout = self.connection_config.get("timeout", 30) - self.retry_attempts = self.connection_config.get("retry_attempts", 3) - self.retry_delay = self.connection_config.get("retry_delay", 1) - - # 健康检查 - self.health_check_interval = self.connection_config.get("health_check_interval", 60) - self._health_check_task = None - 
self._last_health_check = None - - # 事件回调 - self._on_connect_callbacks: List[Callable] = [] - self._on_disconnect_callbacks: List[Callable] = [] - self._on_error_callbacks: List[Callable] = [] + self._pending_requests = {} - async def connect(self) -> bool: - """连接到MCP服务器 - - Returns: - 连接是否成功 - """ + async def __aenter__(self): + """异步上下文管理器入口""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器出口""" + await self.disconnect() + + async def connect(self): + """建立连接""" try: - if self._connected: - return True - - logger.info(f"连接MCP服务器: {self.server_url}") - - if self.connection_type == "websocket": - success = await self._connect_websocket() + if self.is_websocket: + await self._connect_websocket() else: - success = await self._connect_http() - - if success: - self._connected = True - await self._start_health_check() - await self._notify_connect_callbacks() - logger.info(f"MCP服务器连接成功: {self.server_url}") - - return success - + await self._connect_http() except Exception as e: - logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}") - await self._notify_error_callbacks(e) - return False + logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}") + raise MCPConnectionError(f"连接失败: {e}") - async def disconnect(self) -> bool: - """断开MCP服务器连接 - - Returns: - 断开是否成功 - """ + async def disconnect(self): + """断开连接""" try: - if not self._connected: - return True - - logger.info(f"断开MCP服务器连接: {self.server_url}") - - # 停止健康检查 - await self._stop_health_check() - - # 取消所有待处理的请求 - for future in self._pending_requests.values(): - if not future.done(): - future.cancel() - self._pending_requests.clear() - - # 断开连接 - if self.connection_type == "websocket" and self._websocket: + if self._websocket: await self._websocket.close() self._websocket = None - elif self._session: + + if self._session: await self._session.close() self._session = None - - self._connected = False - await self._notify_disconnect_callbacks() - 
logger.info(f"MCP服务器连接已断开: {self.server_url}") - - return True - + except Exception as e: - logger.error(f"断开MCP服务器连接失败: {e}") - return False + logger.error(f"断开连接失败: {e}") - def _build_auth_headers(self) -> Dict[str, str]: - """构建认证头""" - headers = {} - auth_type = self.connection_config.get("auth_type", "none") + async def _connect_websocket(self): + """WebSocket 连接""" + headers = self._build_headers() + + self._websocket = await websockets.connect( + self.server_url, + extra_headers=headers, + timeout=self.timeout + ) + + # 启动消息处理 + asyncio.create_task(self._handle_websocket_messages()) + + # 发送初始化消息 + await self._send_initialize() + + async def _connect_http(self): + """HTTP 连接""" + headers = self._build_headers() + timeout = aiohttp.ClientTimeout(total=self.timeout) + + self._session = aiohttp.ClientSession( + headers=headers, + timeout=timeout + ) + + # 对于 ModelScope MCP 服务,需要先发送初始化请求 + if "modelscope.net" in self.server_url: + await self._initialize_modelscope_session() + + async def _initialize_modelscope_session(self): + """初始化 ModelScope MCP 会话""" + init_request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": { + "name": "MemoryBear", + "version": "1.0.0" + } + } + } + + try: + async with self._session.post( + self.server_url, + json=init_request + ) as response: + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") + + init_response = await response.json() + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") + + # 获取 session ID + session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id") + if session_id: + self._session.headers.update({"Mcp-Session-Id": session_id}) + + # 发送 initialized 通知 + initialized_notification = { + "jsonrpc": "2.0", + "method": 
"notifications/initialized" + } + + async with self._session.post( + self.server_url, + json=initialized_notification + ) as notif_response: + pass + + except aiohttp.ClientError as e: + raise MCPConnectionError(f"初始化连接失败: {e}") + + def _build_headers(self) -> Dict[str, str]: + """构建请求头""" + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream" + } + + # 添加认证头 auth_config = self.connection_config.get("auth_config", {}) + auth_type = self.connection_config.get("auth_type", "none") - if auth_type == "api_key": - api_key = auth_config.get("api_key") - key_name = auth_config.get("key_name", "X-API-Key") - if api_key: - headers[key_name] = api_key - - elif auth_type == "bearer_token": + if auth_type == "bearer_token": token = auth_config.get("token") if token: headers["Authorization"] = f"Bearer {token}" - + elif auth_type == "api_key": + key = auth_config.get("api_key") + header_name = auth_config.get("key_name", "X-API-Key") + if key: + headers[header_name] = key elif auth_type == "basic_auth": username = auth_config.get("username") password = auth_config.get("password") @@ -161,160 +177,63 @@ class MCPClient: return headers - async def _connect_websocket(self) -> bool: - """建立WebSocket连接""" - try: - # WebSocket连接配置 - extra_headers = self.connection_config.get("headers", {}) - auth_headers = self._build_auth_headers() - extra_headers.update(auth_headers) - - self._websocket = await websockets.connect( - self.server_url, - extra_headers=extra_headers, - timeout=self.connection_timeout - ) - - # 启动消息监听 - asyncio.create_task(self._websocket_message_handler()) - - # 发送初始化消息 - init_message = { - "jsonrpc": "2.0", - "id": self._get_next_request_id(), - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": { - "tools": {} - }, - "clientInfo": { - "name": "ToolManagementSystem", - "version": "1.0.0" - } + async def _send_initialize(self): + """发送初始化消息""" + init_message = { + "jsonrpc": "2.0", + 
"id": self._get_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": { + "name": "MemoryBear", + "version": "1.0.0" } } - - await self._websocket.send(json.dumps(init_message)) - - # 等待初始化响应 - response = await asyncio.wait_for( - self._websocket.recv(), - timeout=self.connection_timeout - ) - - init_response = json.loads(response) - if init_response.get("error", None) is not None: - raise MCPProtocolError(f"初始化失败: {init_response['error']}") - - return True - - except Exception as e: - logger.error(f"WebSocket连接失败: {e}") - return False + } + + await self._websocket.send(json.dumps(init_message)) + + # 等待初始化响应 + response = await asyncio.wait_for( + self._websocket.recv(), + timeout=self.timeout + ) + + init_response = json.loads(response) + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") - async def _connect_http(self) -> bool: - """建立HTTP连接""" - try: - # HTTP会话配置 - timeout = aiohttp.ClientTimeout(total=self.connection_timeout) - headers = self.connection_config.get("headers", {}) - auth_headers = self._build_auth_headers() - headers.update(auth_headers) - - self._session = aiohttp.ClientSession( - timeout=timeout, - headers=headers - ) - - # 测试连接 - test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health" - - async with self._session.get(test_url) as response: - if response.status == 200: - return True - else: - # 尝试根路径 - async with self._session.get(self.server_url) as root_response: - return root_response.status < 400 - - except Exception as e: - logger.error(f"HTTP连接失败: {e}") - if self._session: - await self._session.close() - self._session = None - return False - - async def _websocket_message_handler(self): - """WebSocket消息处理器""" + async def _handle_websocket_messages(self): + """处理 WebSocket 消息""" try: while self._websocket and not self._websocket.closed: try: message = await 
self._websocket.recv() - await self._handle_message(json.loads(message)) + data = json.loads(message) + + # 处理响应 + if "id" in data: + request_id = str(data["id"]) + if request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(data) + except ConnectionClosed: break - except json.JSONDecodeError as e: - logger.error(f"解析WebSocket消息失败: {e}") except Exception as e: logger.error(f"处理WebSocket消息失败: {e}") except Exception as e: - logger.error(f"WebSocket消息处理器异常: {e}") - finally: - self._connected = False - await self._notify_disconnect_callbacks() + logger.error(f"WebSocket消息处理异常: {e}") - async def _handle_message(self, message: Dict[str, Any]): - """处理收到的消息""" - try: - # 检查是否是响应消息 - if "id" in message: - request_id = str(message["id"]) - if request_id in self._pending_requests: - future = self._pending_requests.pop(request_id) - if not future.done(): - future.set_result(message) - - # 处理通知消息 - elif "method" in message: - await self._handle_notification(message) - - except Exception as e: - logger.error(f"处理消息失败: {e}") - - @staticmethod - async def _handle_notification(message: Dict[str, Any]): - """处理通知消息""" - method = message.get("method") - params = message.get("params", {}) - - logger.debug(f"收到MCP通知: {method}, 参数: {params}") - - # 这里可以根据需要处理特定的通知 - # 例如:工具列表更新、服务器状态变化等 - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]: - """调用MCP工具 - - Args: - tool_name: 工具名称 - arguments: 工具参数 - timeout: 超时时间(秒) - - Returns: - 工具执行结果 - - Raises: - MCPConnectionError: 连接错误 - MCPProtocolError: 协议错误 - """ - if not self._connected: - raise MCPConnectionError("MCP客户端未连接") - + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + """调用工具""" request_data = { "jsonrpc": "2.0", - "id": self._get_next_request_id(), + "id": self._get_request_id(), "method": "tools/call", "params": { "name": tool_name, @@ -322,343 +241,69 @@ class 
MCPClient: } } - try: - response = await self._send_request(request_data, timeout) - - if response.get("error", None) is not None: - error = response["error"] - raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}") - - return response.get("result", {}) - - except asyncio.TimeoutError: - raise MCPProtocolError(f"工具调用超时: {tool_name}") + if self.is_websocket: + response = await self._send_websocket_request(request_data) + else: + response = await self._send_http_request(request_data) + + if "error" in response: + error = response["error"] + raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") + + return response.get("result", {}) - async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]: - """获取可用工具列表 - - Args: - timeout: 超时时间(秒) - - Returns: - 工具列表 - - Raises: - MCPConnectionError: 连接错误 - MCPProtocolError: 协议错误 - """ - if not self._connected: - raise MCPConnectionError("MCP客户端未连接") - + async def list_tools(self) -> List[Dict[str, Any]]: + """获取工具列表""" request_data = { "jsonrpc": "2.0", - "id": self._get_next_request_id(), - "method": "tools/list" + "id": self._get_request_id(), + "method": "tools/list", + "params": {} } - try: - response = await self._send_request(request_data, timeout) - - if response.get("error", None) is not None: - error = response["error"] - raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}") - - result = response.get("result", {}) - return result.get("tools", []) - - except asyncio.TimeoutError: - raise MCPProtocolError("获取工具列表超时") - - async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]: - """发送请求并等待响应 - - Args: - request_data: 请求数据 - timeout: 超时时间(秒) - - Returns: - 响应数据 - """ - if self.connection_type == "websocket": - request_id = str(request_data["id"]) - return await self._send_websocket_request(request_data, request_id, timeout) + if self.is_websocket: + response = await self._send_websocket_request(request_data) else: - return await 
self._send_http_request(request_data, timeout) - - async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]: - """发送WebSocket请求""" - if not self._websocket or self._websocket.closed: - raise MCPConnectionError("WebSocket连接已断开") + response = await self._send_http_request(request_data) - # 创建Future等待响应 + if "error" in response: + error = response["error"] + raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}") + + result = response.get("result", {}) + return result.get("tools", []) + + async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """发送WebSocket请求""" + request_id = str(request_data["id"]) future = asyncio.Future() self._pending_requests[request_id] = future try: - # 发送请求 await self._websocket.send(json.dumps(request_data)) - - # 等待响应 - response = await asyncio.wait_for(future, timeout=timeout) + response = await asyncio.wait_for(future, timeout=self.timeout) return response - except asyncio.TimeoutError: - await self._pending_requests.pop(request_id, None) + self._pending_requests.pop(request_id, None) raise - except Exception as e: - await self._pending_requests.pop(request_id, None) - raise MCPConnectionError(f"发送WebSocket请求失败: {e}") - async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """发送HTTP请求""" - if not self._session: - raise MCPConnectionError("HTTP会话未建立") - try: - url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp" - async with self._session.post( - url, - json=request_data, - timeout=aiohttp.ClientTimeout(total=timeout) + self.server_url, + json=request_data ) as response: - if response.status == 200: - return await response.json() - else: - async with self._session.post( - self.server_url, - json=request_data, - timeout=aiohttp.ClientTimeout(total=timeout) - ) 
as root_response: - if root_response.status != 200: - error_text = await root_response.text() - raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") - - return await response.json() + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") + + return await response.json() except aiohttp.ClientError as e: raise MCPConnectionError(f"HTTP请求失败: {e}") - async def health_check(self) -> Dict[str, Any]: - """执行健康检查 - - Returns: - 健康状态信息 - """ - try: - if not self._connected: - return { - "healthy": False, - "error": "未连接", - "timestamp": time.time() - } - - # 发送ping请求 - request_data = { - "jsonrpc": "2.0", - "id": self._get_next_request_id(), - "method": "ping" - } - - start_time = time.time() - response = await self._send_request(request_data, timeout=5) - response_time = round((time.time() - start_time) * 1000) - - self._last_health_check = round(time.time() * 1000) - - return { - "healthy": True, - "response_time": response_time, - "timestamp": self._last_health_check, - "server_info": response.get("result", {}) - } - - except Exception as e: - return { - "healthy": False, - "error": str(e), - "timestamp": time.time() - } - - async def _start_health_check(self): - """启动健康检查任务""" - if self.health_check_interval > 0: - self._health_check_task = asyncio.create_task(self._health_check_loop()) - - async def _stop_health_check(self): - """停止健康检查任务""" - if self._health_check_task: - self._health_check_task.cancel() - try: - await self._health_check_task - except asyncio.CancelledError: - pass - self._health_check_task = None - - async def _health_check_loop(self): - """健康检查循环""" - try: - while self._connected: - await asyncio.sleep(self.health_check_interval) - - if self._connected: - health_status = await self.health_check() - if not health_status["healthy"]: - logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}") - # 可以在这里实现重连逻辑 - - except asyncio.CancelledError: - pass - 
except Exception as e: - logger.error(f"健康检查循环异常: {e}") - - def _get_next_request_id(self) -> str: - """获取下一个请求ID""" + def _get_request_id(self) -> str: + """获取请求ID""" self._request_id += 1 - return f"req_{self._request_id}_{int(time.time() * 1000)}" - - # 事件回调管理 - def on_connect(self, callback: Callable): - """注册连接回调""" - self._on_connect_callbacks.append(callback) - - def on_disconnect(self, callback: Callable): - """注册断开连接回调""" - self._on_disconnect_callbacks.append(callback) - - def on_error(self, callback: Callable): - """注册错误回调""" - self._on_error_callbacks.append(callback) - - async def _notify_connect_callbacks(self): - """通知连接回调""" - for callback in self._on_connect_callbacks: - try: - if asyncio.iscoroutinefunction(callback): - await callback() - else: - callback() - except Exception as e: - logger.error(f"连接回调执行失败: {e}") - - async def _notify_disconnect_callbacks(self): - """通知断开连接回调""" - for callback in self._on_disconnect_callbacks: - try: - if asyncio.iscoroutinefunction(callback): - await callback() - else: - callback() - except Exception as e: - logger.error(f"断开连接回调执行失败: {e}") - - async def _notify_error_callbacks(self, error: Exception): - """通知错误回调""" - for callback in self._on_error_callbacks: - try: - if asyncio.iscoroutinefunction(callback): - await callback(error) - else: - callback(error) - except Exception as e: - logger.error(f"错误回调执行失败: {e}") - - @property - def is_connected(self) -> bool: - """检查是否已连接""" - return self._connected - - @property - def last_health_check(self) -> Optional[float]: - """获取最后一次健康检查时间""" - return self._last_health_check - - def get_connection_info(self) -> Dict[str, Any]: - """获取连接信息""" - return { - "server_url": self.server_url, - "connection_type": self.connection_type, - "connected": self._connected, - "last_health_check": self._last_health_check, - "pending_requests": len(self._pending_requests), - "config": self.connection_config - } - - async def __aenter__(self): - """异步上下文管理器入口""" - await self.connect() - 
return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" - await self.disconnect() - - -class MCPConnectionPool: - """MCP连接池 - 管理多个MCP客户端连接""" - - def __init__(self, max_connections: int = 10): - """初始化连接池 - - Args: - max_connections: 最大连接数 - """ - self.max_connections = max_connections - self._clients: Dict[str, MCPClient] = {} - self._lock = asyncio.Lock() - - async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient: - """获取或创建MCP客户端 - - Args: - server_url: 服务器URL - connection_config: 连接配置 - - Returns: - MCP客户端实例 - """ - async with self._lock: - if server_url in self._clients: - client = self._clients[server_url] - if client.is_connected: - return client - else: - # 尝试重连 - if await client.connect(): - return client - else: - # 移除失效的客户端 - del self._clients[server_url] - - # 检查连接数限制 - if len(self._clients) >= self.max_connections: - # 移除最旧的连接 - oldest_url = next(iter(self._clients)) - await self._clients[oldest_url].disconnect() - del self._clients[oldest_url] - - # 创建新客户端 - client = MCPClient(server_url, connection_config) - if await client.connect(): - self._clients[server_url] = client - return client - else: - raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}") - - async def disconnect_all(self): - """断开所有连接""" - async with self._lock: - for client in self._clients.values(): - await client.disconnect() - self._clients.clear() - - def get_pool_status(self) -> Dict[str, Any]: - """获取连接池状态""" - return { - "total_connections": len(self._clients), - "max_connections": self.max_connections, - "connections": { - url: client.get_connection_info() - for url, client in self._clients.items() - } - } \ No newline at end of file + return f"req_{self._request_id}_{int(time.time() * 1000)}" \ No newline at end of file diff --git a/api/app/core/tools/mcp/service_manager.py b/api/app/core/tools/mcp/service_manager.py index f7349201..2144999a 100644 --- a/api/app/core/tools/mcp/service_manager.py +++ 
b/api/app/core/tools/mcp/service_manager.py @@ -1,6 +1,4 @@ -"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控""" -import asyncio -import time +"""MCP服务管理器 - 简化版本""" import uuid from typing import Dict, Any, List, Optional, Tuple from datetime import datetime @@ -8,133 +6,53 @@ from sqlalchemy.orm import Session from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus from app.core.logging_config import get_business_logger -from app.core.tools.mcp.client import MCPClient, MCPConnectionPool +from app.core.tools.mcp.base import MCPToolManager logger = get_business_logger() class MCPServiceManager: - """MCP服务管理器 - 管理MCP服务的生命周期""" + """MCP服务管理器 - 简化版本,主要用于工具创建""" - def __init__(self, db: Session): - """初始化MCP服务管理器 - - Args: - db: 数据库会话 - """ + def __init__(self, db: Session = None): self.db = db - self.connection_pool = MCPConnectionPool(max_connections=20) - - # 服务状态管理 - self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info - self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task - - # 配置 - self.health_check_interval = 60 # 健康检查间隔(秒) - self.max_retry_attempts = 3 # 最大重试次数 - self.retry_delay = 5 # 重试延迟(秒) - - # 状态 - self._running = False - self._manager_task = None + self.tool_manager = MCPToolManager(db) if db else None - async def start(self): - """启动服务管理器""" - if self._running: - return - - self._running = True - logger.info("MCP服务管理器启动") - - # 加载现有服务 - await self._load_existing_services() - - # 启动管理任务 - self._manager_task = asyncio.create_task(self._management_loop()) - - async def stop(self): - """停止服务管理器""" - if not self._running: - return - - self._running = False - logger.info("MCP服务管理器停止") - - # 停止管理任务 - if self._manager_task: - self._manager_task.cancel() - try: - await self._manager_task - except asyncio.CancelledError: - pass - - # 停止所有监控任务 - for task in self._monitoring_tasks.values(): - task.cancel() - - if self._monitoring_tasks: - await asyncio.gather(*self._monitoring_tasks.values(), 
return_exceptions=True) - - self._monitoring_tasks.clear() - - # 断开所有连接 - await self.connection_pool.disconnect_all() - - async def register_service( + async def create_mcp_tool( self, server_url: str, connection_config: Dict[str, Any], tenant_id: uuid.UUID, + tool_name: str, service_name: str = None ) -> Tuple[bool, str, Optional[str]]: - """注册MCP服务 + """创建单个MCP工具 Args: server_url: 服务器URL connection_config: 连接配置 tenant_id: 租户ID - service_name: 服务名称(可选) + tool_name: 具体工具名称 + service_name: 服务名称 Returns: - (是否成功, 服务ID或错误信息, 错误详情) + (是否成功, 工具ID或错误信息, 错误详情) """ try: - # 检查服务是否已存在 - existing_service = self.db.query(MCPToolConfig).filter( - MCPToolConfig.server_url == server_url - ).first() - - if existing_service: - return False, "服务已存在", f"URL {server_url} 已被注册" - - # 测试连接 - try: - client = MCPClient(server_url, connection_config) - if not await client.connect(): - return False, "连接测试失败", "无法连接到MCP服务器" - - # 获取可用工具 - available_tools = await client.list_tools() - tool_names = [tool.get("name") for tool in available_tools if tool.get("name")] - - await client.disconnect() - - except Exception as e: - return False, "连接测试失败", str(e) + if not service_name: + service_name = f"mcp_{tool_name}" # 创建工具配置 - if not service_name: - service_name = f"mcp_service_{server_url.split('/')[-1]}" - tool_config = ToolConfig( name=service_name, - description=f"MCP服务 - {server_url}", + description=f"MCP工具: {tool_name}", tool_type=ToolType.MCP.value, tenant_id=tenant_id, - version="1.0.0", + status=ToolStatus.AVAILABLE.value, config_data={ "server_url": server_url, - "connection_config": connection_config + "connection_config": connection_config, + "tool_name": tool_name } ) @@ -146,460 +64,22 @@ class MCPServiceManager: id=tool_config.id, server_url=server_url, connection_config=connection_config, - available_tools=tool_names, - health_status="healthy", + available_tools=[tool_name], + health_status="unknown", last_health_check=datetime.now() ) self.db.add(mcp_config) self.db.commit() - 
service_id = str(tool_config.id) - - # 添加到内存管理 - self._services[service_id] = { - "id": service_id, - "server_url": server_url, - "connection_config": connection_config, - "tenant_id": tenant_id, - "available_tools": tool_names, - "status": "healthy", - "last_health_check": time.time(), - "retry_count": 0, - "created_at": time.time() - } - - # 启动监控 - await self._start_service_monitoring(service_id) - - logger.info(f"MCP服务注册成功: {service_id} ({server_url})") - return True, service_id, None + logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})") + return True, str(tool_config.id), None except Exception as e: self.db.rollback() - logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}") - return False, "注册失败", str(e) + logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}") + return False, "创建失败", str(e) - async def unregister_service(self, service_id: str) -> Tuple[bool, str]: - """注销MCP服务 - - Args: - service_id: 服务ID - - Returns: - (是否成功, 错误信息) - """ - try: - # 从数据库删除 - tool_config = self.db.get(ToolConfig, uuid.UUID(service_id)) - if not tool_config: - return False, "服务不存在" - - self.db.delete(tool_config) - self.db.commit() - - # 停止监控 - await self._stop_service_monitoring(service_id) - - # 从内存移除 - if service_id in self._services: - del self._services[service_id] - - logger.info(f"MCP服务注销成功: {service_id}") - return True, "" - - except Exception as e: - self.db.rollback() - logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}") - return False, str(e) - - async def update_service( - self, - service_id: str, - connection_config: Dict[str, Any] = None, - enabled: bool = None - ) -> Tuple[bool, str]: - """更新MCP服务配置 - - Args: - service_id: 服务ID - connection_config: 新的连接配置 - enabled: 是否启用 - - Returns: - (是否成功, 错误信息) - """ - try: - # 更新数据库 - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == uuid.UUID(service_id) - ).first() - - if not mcp_config: - return False, "服务不存在" - - tool_config = mcp_config.base_config - - if connection_config is not None: - 
mcp_config.connection_config = connection_config - tool_config.config_data["connection_config"] = connection_config - - if enabled is not None: - tool_config.is_enabled = enabled - - self.db.commit() - - # 更新内存状态 - if service_id in self._services: - if connection_config is not None: - self._services[service_id]["connection_config"] = connection_config - - # 如果配置有变化,重启监控 - if connection_config is not None: - await self._restart_service_monitoring(service_id) - - logger.info(f"MCP服务更新成功: {service_id}") - return True, "" - - except Exception as e: - self.db.rollback() - logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}") - return False, str(e) - - async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]: - """获取服务状态 - - Args: - service_id: 服务ID - - Returns: - 服务状态信息 - """ - if service_id not in self._services: - return None - - service_info = self._services[service_id].copy() - - # 添加实时健康检查 - try: - client = await self.connection_pool.get_client( - service_info["server_url"], - service_info["connection_config"] - ) - - health_status = await client.health_check() - service_info["real_time_health"] = health_status - - except Exception as e: - service_info["real_time_health"] = { - "healthy": False, - "error": str(e), - "timestamp": time.time() - } - - return service_info - - async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]: - """列出所有服务 - - Args: - tenant_id: 租户ID过滤 - - Returns: - 服务列表 - """ - services = [] - - for service_id, service_info in self._services.items(): - if tenant_id and service_info["tenant_id"] != tenant_id: - continue - - services.append(service_info.copy()) - - return services - - async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]: - """获取服务的可用工具 - - Args: - service_id: 服务ID - - Returns: - 工具列表 - """ - if service_id not in self._services: - return [] - - service_info = self._services[service_id] - - try: - client = await self.connection_pool.get_client( - 
service_info["server_url"], - service_info["connection_config"] - ) - - tools = await client.list_tools() - - # 更新缓存的工具列表 - tool_names = [tool.get("name") for tool in tools if tool.get("name")] - service_info["available_tools"] = tool_names - - # 更新数据库 - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == uuid.UUID(service_id) - ).first() - - if mcp_config: - mcp_config.available_tools = tool_names - self.db.commit() - - return tools - - except Exception as e: - logger.error(f"获取服务工具失败: {service_id}, 错误: {e}") - return [] - - async def call_service_tool( - self, - service_id: str, - tool_name: str, - arguments: Dict[str, Any], - timeout: int = 30 - ) -> Dict[str, Any]: - """调用服务工具 - - Args: - service_id: 服务ID - tool_name: 工具名称 - arguments: 工具参数 - timeout: 超时时间 - - Returns: - 执行结果 - """ - if service_id not in self._services: - raise ValueError(f"服务不存在: {service_id}") - - service_info = self._services[service_id] - - try: - client = await self.connection_pool.get_client( - service_info["server_url"], - service_info["connection_config"] - ) - - result = await client.call_tool(tool_name, arguments, timeout) - - # 更新服务状态为健康 - service_info["status"] = "healthy" - service_info["last_health_check"] = time.time() - service_info["retry_count"] = 0 - - return result - - except Exception as e: - # 更新服务状态为错误 - service_info["status"] = "error" - service_info["last_error"] = str(e) - service_info["retry_count"] += 1 - - logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}") - raise - - async def _load_existing_services(self): - """加载现有服务""" - try: - mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter( - ToolConfig.status == ToolStatus.AVAILABLE.value, - ToolConfig.tool_type == ToolType.MCP.value - ).all() - - for mcp_config in mcp_configs: - tool_config = mcp_config.base_config - service_id = str(mcp_config.id) - - self._services[service_id] = { - "id": service_id, - "server_url": mcp_config.server_url, - "connection_config": 
mcp_config.connection_config or {}, - "tenant_id": tool_config.tenant_id, - "available_tools": mcp_config.available_tools or [], - "status": mcp_config.health_status or "unknown", - "last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0, - "retry_count": 0, - "created_at": tool_config.created_at.timestamp() - } - - # 启动监控 - await self._start_service_monitoring(service_id) - - logger.info(f"加载了 {len(mcp_configs)} 个MCP服务") - - except Exception as e: - logger.error(f"加载现有服务失败: {e}") - - async def _start_service_monitoring(self, service_id: str): - """启动服务监控""" - if service_id in self._monitoring_tasks: - return - - task = asyncio.create_task(self._monitor_service(service_id)) - self._monitoring_tasks[service_id] = task - - async def _stop_service_monitoring(self, service_id: str): - """停止服务监控""" - if service_id in self._monitoring_tasks: - task = self._monitoring_tasks.pop(service_id) - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - async def _restart_service_monitoring(self, service_id: str): - """重启服务监控""" - await self._stop_service_monitoring(service_id) - await self._start_service_monitoring(service_id) - - async def _monitor_service(self, service_id: str): - """监控单个服务""" - try: - while self._running and service_id in self._services: - service_info = self._services[service_id] - - try: - # 执行健康检查 - client = await self.connection_pool.get_client( - service_info["server_url"], - service_info["connection_config"] - ) - - health_status = await client.health_check() - - if health_status["healthy"]: - # 服务健康 - service_info["status"] = "healthy" - service_info["retry_count"] = 0 - - # 更新工具列表 - try: - tools = await client.list_tools() - tool_names = [tool.get("name") for tool in tools if tool.get("name")] - service_info["available_tools"] = tool_names - except Exception as e: - logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}") - - else: - # 服务不健康 - service_info["status"] = "unhealthy" - 
service_info["last_error"] = health_status.get("error", "健康检查失败") - service_info["retry_count"] += 1 - - service_info["last_health_check"] = time.time() - - # 更新数据库 - await self._update_service_health_in_db(service_id, health_status) - - except Exception as e: - # 监控异常 - service_info["status"] = "error" - service_info["last_error"] = str(e) - service_info["retry_count"] += 1 - service_info["last_health_check"] = time.time() - - logger.error(f"服务监控异常: {service_id}, 错误: {e}") - - # 如果重试次数过多,暂停监控 - if service_info["retry_count"] >= self.max_retry_attempts: - logger.warning(f"服务 {service_id} 重试次数过多,暂停监控") - await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间 - service_info["retry_count"] = 0 # 重置重试计数 - - # 等待下次检查 - await asyncio.sleep(self.health_check_interval) - - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"服务监控任务异常: {service_id}, 错误: {e}") - - async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]): - """更新数据库中的服务健康状态""" - try: - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == uuid.UUID(service_id) - ).first() - - if mcp_config: - mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy" - mcp_config.last_health_check = datetime.now() - - if not health_status["healthy"]: - mcp_config.error_message = health_status.get("error", "") - else: - mcp_config.error_message = None - - self.db.commit() - - except Exception as e: - logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}") - self.db.rollback() - - async def _management_loop(self): - """管理循环 - 处理服务清理等任务""" - try: - while self._running: - # 清理失效的服务 - await self._cleanup_failed_services() - - # 等待下次循环 - await asyncio.sleep(300) # 5分钟 - - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"管理循环异常: {e}") - - async def _cleanup_failed_services(self): - """清理长期失效的服务""" - try: - current_time = time.time() - cleanup_threshold = 24 * 60 * 60 # 24小时 - - services_to_cleanup 
= [] - - for service_id, service_info in self._services.items(): - # 检查服务是否长期失效 - if (service_info["status"] in ["error", "unhealthy"] and - current_time - service_info["last_health_check"] > cleanup_threshold): - - services_to_cleanup.append(service_id) - - for service_id in services_to_cleanup: - logger.warning(f"清理长期失效的服务: {service_id}") - - # 停止监控但不删除数据库记录 - await self._stop_service_monitoring(service_id) - - # 标记为禁用 - tool_config = self.db.get(ToolConfig, uuid.UUID(service_id)) - if tool_config: - tool_config.is_enabled = False - self.db.commit() - - # 从内存移除 - del self._services[service_id] - - except Exception as e: - logger.error(f"清理失效服务失败: {e}") - - def get_manager_status(self) -> Dict[str, Any]: - """获取管理器状态""" - return { - "running": self._running, - "total_services": len(self._services), - "healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]), - "unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]), - "monitoring_tasks": len(self._monitoring_tasks), - "connection_pool_status": self.connection_pool.get_pool_status() - } \ No newline at end of file + def get_tool_manager(self) -> MCPToolManager: + """获取工具管理器实例""" + return self.tool_manager \ No newline at end of file diff --git a/api/app/core/validators/memory_config_validators.py b/api/app/core/validators/memory_config_validators.py index eb2aaad8..6ccf3ddb 100644 --- a/api/app/core/validators/memory_config_validators.py +++ b/api/app/core/validators/memory_config_validators.py @@ -64,6 +64,11 @@ def validate_model_exists_and_active( ) -> tuple[str, bool]: """Validate that a model exists and is active. 
+ This function performs tenant-aware model validation with detailed error messages: + - If model doesn't exist at all: "Model not found" + - If model exists but belongs to different tenant: "Model belongs to different tenant" with details + - If model exists and accessible but inactive: "Model is inactive" + Args: model_id: Model UUID to validate model_type: Type of model ("llm", "embedding", "rerank") @@ -76,7 +81,7 @@ def validate_model_exists_and_active( Tuple of (model_name, is_active) Raises: - ModelNotFoundError: If model does not exist + ModelNotFoundError: If model does not exist or belongs to different tenant ModelInactiveError: If model exists but is inactive """ from app.repositories.model_repository import ModelConfigRepository @@ -84,21 +89,48 @@ def validate_model_exists_and_active( start_time = time.time() try: + # First check if model exists at all (without tenant filtering) + model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) + + # Then check with tenant filtering model = ModelConfigRepository.get_by_id(db, model_id, tenant_id) elapsed_ms = (time.time() - start_time) * 1000 if not model: - logger.warning( - "Model not found", - extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms} - ) - raise ModelNotFoundError( - model_id=model_id, - model_type=model_type, - config_id=config_id, - workspace_id=workspace_id, - message=f"{model_type.title()} model {model_id} not found" - ) + if model_without_tenant: + # Model exists but belongs to different tenant + logger.warning( + "Model belongs to different tenant", + extra={ + "model_id": str(model_id), + "model_type": model_type, + "model_name": model_without_tenant.name, + "model_tenant_id": str(model_without_tenant.tenant_id), + "requested_tenant_id": str(tenant_id), + "is_public": model_without_tenant.is_public, + "elapsed_ms": elapsed_ms + } + ) + raise ModelNotFoundError( + model_id=model_id, + model_type=model_type, + config_id=config_id, 
+ workspace_id=workspace_id, + message=f"{model_type.title()} model {model_id} ({model_without_tenant.name}) belongs to a different tenant (model tenant: {model_without_tenant.tenant_id}, workspace tenant: {tenant_id}). The model is not public and cannot be accessed from this workspace." + ) + else: + # Model doesn't exist at all + logger.warning( + "Model not found", + extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms} + ) + raise ModelNotFoundError( + model_id=model_id, + model_type=model_type, + config_id=config_id, + workspace_id=workspace_id, + message=f"{model_type.title()} model {model_id} not found" + ) if not model.is_active: logger.warning( diff --git a/api/app/core/workflow/nodes/assigner/config.py b/api/app/core/workflow/nodes/assigner/config.py index 092f0b51..dd8a460e 100644 --- a/api/app/core/workflow/nodes/assigner/config.py +++ b/api/app/core/workflow/nodes/assigner/config.py @@ -22,7 +22,7 @@ class AssignmentItem(BaseModel): ) value: Any = Field( - ..., + default=None, description="Value(s) to assign to the variable(s)", ) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 8eb31fb4..e7007884 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -29,7 +29,7 @@ class WorkflowState(TypedDict): # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list - looping: bool + looping: Annotated[bool, lambda x, y: x and y] # Input variables (passed from configured variables) # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) @@ -534,7 +534,7 @@ class BaseNode(ABC): return edge return None - def _render_template(self, template: str, state: WorkflowState | None) -> str: + def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str: """渲染模板 支持的变量命名空间: @@ -568,7 +568,8 @@ class BaseNode(ABC): template=template, variables=variables, 
node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars() + system_vars=pool.get_all_system_vars(), + struct=struct ) def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py index fcf65717..445ddd9a 100644 --- a/api/app/core/workflow/nodes/cycle_graph/config.py +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -1,6 +1,6 @@ from typing import Any -from pydantic import Field, BaseModel +from pydantic import Field, BaseModel, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType @@ -27,6 +27,16 @@ class CycleVariable(BaseNodeConfig): description="Initial or current value of the loop variable" ) + @field_validator("input_type", mode="before") + @classmethod + def lower_input_type(cls, v): + if isinstance(v, str): + try: + return ValueInputType(v.lower()) + except ValueError: + raise ValueError(f"Invalid input_type: {v}") + return v + class ConditionDetail(BaseModel): operator: ComparisonOperator = Field( @@ -45,10 +55,20 @@ class ConditionDetail(BaseModel): ) input_type: ValueInputType = Field( - ..., + default=ValueInputType.CONSTANT, description="Input type of the loop variable" ) + @field_validator("input_type", mode="before") + @classmethod + def lower_input_type(cls, v): + if isinstance(v, str): + try: + return ValueInputType(v.lower()) + except ValueError: + raise ValueError(f"Invalid input_type: {v}") + return v + class ConditionsConfig(BaseModel): """Configuration for loop condition evaluation""" diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index efc62dc5..6230345c 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -37,7 +37,7 @@ class EndNode(BaseNode): # 
如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: - output = self._render_template(output_template, state) + output = self._render_template(output_template, state, struct=False) else: output = "工作流已完成" diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index fbbbf845..c294bb11 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -93,5 +93,5 @@ class HttpErrorHandle(StrEnum): class ValueInputType(StrEnum): - VARIABLE = "Variable" - CONSTANT = "Constant" + VARIABLE = "variable" + CONSTANT = "constant" diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index 6bb7baaf..9b41d9f2 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -73,8 +73,10 @@ class HttpContentTypeConfig(BaseModel): content_type = info.data.get("content_type") if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData): raise ValueError("When content_type is 'form-data', data must be of type HttpFormData") - elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict): - raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object") + elif content_type in [HttpContentType.JSON] and not isinstance(v, str): + raise ValueError("When content_type is JSON, data must be of type str") + elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict): + raise ValueError("When content_type is x-www-form-urlencoded, data must be an object(dict)") elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str): raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)") return v diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 55919998..2e5de796 100644 --- 
a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -120,7 +120,7 @@ class HttpRequestNode(BaseNode): return {} case HttpContentType.JSON: content["json"] = json.loads(self._render_template( - json.dumps(self.typed_config.body.data), state + self.typed_config.body.data, state )) case HttpContentType.FROM_DATA: data = {} @@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode): retries -= 1 if retries > 0: await asyncio.sleep(self.typed_config.retry.retry_interval / 1000) + elif self.typed_config.error_handle.method == HttpErrorHandle.NONE: + raise e + except Exception as e: + raise RuntimeError(f"HTTP request node exception: {e}") else: match self.typed_config.error_handle.method: - case HttpErrorHandle.NONE: - logger.warning( - f"Node {self.node_id}: HTTP request failed, returning error response" - ) - return HttpRequestNodeOutput( - body="", - status_code=resp.status_code, - headers=resp.headers, - ).model_dump() case HttpErrorHandle.DEFAULT: logger.warning( f"Node {self.node_id}: HTTP request failed, returning default result" @@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode): f"Node {self.node_id}: HTTP request failed, switching to error handling branch" ) return "ERROR" + raise RuntimeError("http request failed") diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 4dcb00d1..3e5ea22a 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -23,10 +23,20 @@ class ConditionDetail(BaseModel): ) input_type: ValueInputType = Field( - ..., + default=ValueInputType.CONSTANT, description="Value input type for comparison" ) + @field_validator("input_type", mode="before") + @classmethod + def lower_input_type(cls, v): + if isinstance(v, str): + try: + return ValueInputType(v.lower()) + except ValueError: + raise ValueError(f"Invalid input_type: {v}") + return v + class 
ConditionBranchConfig(BaseModel): """Configuration for a conditional branch""" diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index fd5864a8..8c6d222f 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -71,7 +71,10 @@ class IfElseNode(BaseNode): for expression in case_branch.expressions: pattern = r"\{\{\s*(.*?)\s*\}\}" left_string = re.sub(pattern, r"\1", expression.left).strip() - left_value = self.get_variable(left_string, state) + try: + left_value = self.get_variable(left_string, state) + except KeyError: + left_value = None evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( self.get_variable_pool(state), expression.left, diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index e12c6224..5a6b2a7f 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -203,15 +203,20 @@ class KnowledgeRetrievalNode(BaseNode): rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, indices=indices, score_threshold=kb_config.similarity_threshold) - # Deduplicate hybrid retrieval results + # Deduplicate hybrid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) + if not unique_rs: + continue vector_service.reranker = self.get_reranker_model() rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) case _: raise RuntimeError("Unknown retrieval type") + if not rs: + return [] vector_service.reranker = self.get_reranker_model() + # TODO:其他重排序方式支持 final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) logger.info( f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" ) - return [chunk.model_dump() for chunk in final_rs] + return [chunk.page_content for chunk in final_rs] diff --git
a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index da94482b..8498fc38 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -1,5 +1,7 @@ """LLM 节点配置""" +from typing import Any + from pydantic import BaseModel, Field, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType @@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti class MessageConfig(BaseModel): """消息配置""" - + role: str = Field( ..., description="消息角色:system, user, assistant" ) - + content: str = Field( ..., description="消息内容,支持模板变量,如:{{ sys.message }}" ) - + @field_validator("role") @classmethod def validate_role(cls, v: str) -> str: @@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig): 1. 简单模式:使用 prompt 字段 2. 消息模式:使用 messages 字段(推荐) """ - + model_id: str = Field( ..., description="模型配置 ID" ) - + + context: Any = Field( + default="", + description="上下文" + ) + # 简单模式 prompt: str | None = Field( default=None, description="提示词模板(简单模式),支持变量引用" ) - + # 消息模式(推荐) messages: list[MessageConfig] | None = Field( default=None, description="消息列表(消息模式),支持多轮对话" ) - + # 模型参数 temperature: float | None = Field( default=0.7, @@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig): le=2.0, description="温度参数,控制输出的随机性" ) - + max_tokens: int | None = Field( default=1000, ge=1, le=32000, description="最大生成 token 数" ) - + top_p: float | None = Field( default=None, ge=0.0, le=1.0, description="Top-p 采样参数" ) - + frequency_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="频率惩罚" ) - + presence_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="存在惩罚" ) - + # 输出变量定义 output_variables: list[VariableDefinition] = Field( default_factory=lambda: [ @@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig): ], description="输出变量定义(自动生成,通常不需要修改)" ) - + @field_validator("messages", "prompt") 
@classmethod def validate_input_mode(cls, v, info): """验证输入模式:prompt 和 messages 至少有一个""" # 这个验证在 model_validator 中更合适 return v - + class Config: json_schema_extra = { "examples": [ diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 65826d84..5fb86ae2 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -5,15 +5,17 @@ LLM 节点实现 """ import logging +import re from typing import Any from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.models import RedBearLLM, RedBearModelConfig +from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.db import get_db_context from app.models import ModelType from app.services.model_service import ModelConfigService - + from app.core.exceptions import BusinessException from app.core.error_codes import BizCode @@ -63,8 +65,16 @@ class LLMNode(BaseNode): - user/human: 用户消息(HumanMessage) - ai/assistant: AI 消息(AIMessage) """ - - def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: + + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = LLMNodeConfig(**self.config) + + def _render_context(self, message, state): + context = f"{self._render_template(self.typed_config.context, state)}" + return re.sub(r"{{context}}", context, message) + + def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]: """准备 LLM 实例(公共逻辑) Args: @@ -76,15 +86,16 @@ class LLMNode(BaseNode): # 1. 
处理消息格式(优先使用 messages) messages_config = self.config.get("messages") - + if messages_config: # 使用 LangChain 消息格式 messages = [] for msg_config in messages_config: role = msg_config.get("role", "user").lower() content_template = msg_config.get("content", "") + content_template = self._render_context(content_template, state) content = self._render_template(content_template, state) - + # 根据角色创建对应的消息对象 if role == "system": messages.append(SystemMessage(content=content)) @@ -95,7 +106,7 @@ class LLMNode(BaseNode): else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append(HumanMessage(content=content)) - + prompt_or_messages = messages else: # 使用简单的 prompt 格式(向后兼容) @@ -106,17 +117,17 @@ class LLMNode(BaseNode): model_id = self.config.get("model_id") if not model_id: raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") - + # 3. 在 with 块内完成所有数据库操作和数据提取 with get_db_context() as db: config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - - if not config: + + if not config: raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) - + if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - + # 在 Session 关闭前提取所有需要的数据 api_config = config.api_keys[0] model_name = api_config.model_name @@ -124,26 +135,26 @@ class LLMNode(BaseNode): api_key = api_config.api_key api_base = api_config.api_base model_type = config.type - + # 4. 
创建 LLM 实例(使用已提取的数据) # 注意:对于流式输出,需要在模型初始化时设置 streaming=True extra_params = {"streaming": stream} if stream else {} - + llm = RedBearLLM( RedBearModelConfig( model_name=model_name, - provider=provider, + provider=provider, api_key=api_key, base_url=api_base, extra_params=extra_params - ), + ), type=ModelType(model_type) ) - + logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") - + return llm, prompt_or_messages - + async def execute(self, state: WorkflowState) -> AIMessage: """非流式执行 LLM 调用 @@ -153,10 +164,10 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ - llm, prompt_or_messages = self._prepare_llm(state,True) - + llm, prompt_or_messages = self._prepare_llm(state, True) + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") - + # 调用 LLM(支持字符串或消息列表) response = await llm.ainvoke(prompt_or_messages) # 提取内容 @@ -164,16 +175,16 @@ class LLMNode(BaseNode): content = response.content else: content = str(response) - + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") - + # 返回 AIMessage(包含响应元数据) return response if isinstance(response, AIMessage) else AIMessage(content=content) - + def _extract_input(self, state: WorkflowState) -> dict[str, Any]: """提取输入数据(用于记录)""" _, prompt_or_messages = self._prepare_llm(state) - + return { "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "messages": [ @@ -186,13 +197,13 @@ class LLMNode(BaseNode): "max_tokens": self.config.get("max_tokens") } } - + def _extract_output(self, business_result: Any) -> str: """从 AIMessage 中提取文本内容""" if isinstance(business_result, AIMessage): return business_result.content return str(business_result) - + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: """从 AIMessage 中提取 token 使用情况""" if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): @@ -204,7 +215,7 @@ class LLMNode(BaseNode): "total_tokens": usage.get('total_tokens', 0) } return None - + async def 
execute_stream(self, state: WorkflowState): """流式执行 LLM 调用 @@ -215,26 +226,26 @@ class LLMNode(BaseNode): 文本片段(chunk)或完成标记 """ from langgraph.config import get_stream_writer - + llm, prompt_or_messages = self._prepare_llm(state, True) - + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") - + # 检查是否有注入的 End 节点前缀配置 writer = get_stream_writer() end_prefix = getattr(self, '_end_node_prefix', None) - + logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}") if end_prefix: logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'") - + if end_prefix: # 渲染前缀(可能包含其他变量) try: rendered_prefix = self._render_template(end_prefix, state) logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'") - + # 提前发送 End 节点的前缀(使用 "message" 类型) writer({ "type": "message", # End 相关的内容都是 message 类型 @@ -246,12 +257,12 @@ class LLMNode(BaseNode): }) except Exception as e: logger.warning(f"渲染/发送 End 节点前缀失败: {e}") - + # 累积完整响应 full_response = "" last_chunk = None chunk_count = 0 - + # 调用 LLM(流式,支持字符串或消息列表) async for chunk in llm.astream(prompt_or_messages): # 提取内容 @@ -259,18 +270,18 @@ class LLMNode(BaseNode): content = chunk.content else: content = str(chunk) - + # 只有当内容不为空时才处理 if content: full_response += content last_chunk = chunk chunk_count += 1 - + # 流式返回每个文本片段 yield content - + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") - + # 构建完整的 AIMessage(包含元数据) if isinstance(last_chunk, AIMessage): final_message = AIMessage( @@ -279,6 +290,6 @@ class LLMNode(BaseNode): ) else: final_message = AIMessage(content=full_response) - + # yield 完成标记 yield {"__final__": True, "result": final_message} diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 09c9fc68..bb2366f6 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -24,7 +24,7 @@ class 
MemoryReadNode(BaseNode): return await MemoryAgentService().read_memory( group_id=end_user_id, - message=self.typed_config.message, + message=self._render_template(self.typed_config.message, state), config_id=self.typed_config.config_id, search_switch=self.typed_config.search_switch, history=[], @@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode): return await MemoryAgentService().write_memory( group_id=end_user_id, - message=self.typed_config.message, + message=self._render_template(self.typed_config.message, state), config_id=self.typed_config.config_id, db=db, storage_type="neo4j", diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index fc856aee..ab6ad3e1 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -387,6 +387,11 @@ class ArrayComparisonOperator(ConditionBase): return self.right_value not in self.left_value +class NoneObjectComparisonOperator(ConditionBase): + def __getattr__(self, name): + return lambda *args, **kwargs: False + + CompareOperatorInstance = Union[ StringComparisonOperator, NumberComparisonOperator, @@ -405,6 +410,7 @@ class ConditionExpressionResolver: float: NumberComparisonOperator, list: ArrayComparisonOperator, dict: ObjectComparisonOperator, + type(None): NoneObjectComparisonOperator } @classmethod diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 67f53801..b0f2c28d 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode): category_map[category_name] = case_tag return category_map - async def execute(self, state: WorkflowState) -> str: + async def execute(self, state: WorkflowState) -> dict: """执行问题分类""" question = self.typed_config.input_variable supplement_prompt = self.typed_config.user_supplement_prompt or "" @@ -79,7 +79,15 
@@ class QuestionClassifierNode(BaseNode): f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})" ) # 若分类列表为空,返回默认unknown分支,否则返回CASE1 - return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown" + if category_count > 0: + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } try: llm = self._get_llm_instance() @@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode): log_supplement = supplement_prompt if supplement_prompt else "无" logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - return f"CASE{category_names.index(category) + 1}" + return { + "class_name": category, + "output": f"CASE{category_names.index(category) + 1}", + } except Exception as e: logger.error( f"节点 {self.node_id} 分类执行异常:{str(e)}", @@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode): ) # 异常时返回默认分支,保证工作流容错性 if category_count > 0: - return DEFAULT_EMPTY_QUESTION_CASE - return "unknown" + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } diff --git a/api/app/core/workflow/nodes/tool/config.py b/api/app/core/workflow/nodes/tool/config.py index 487efae2..d3b1a644 100644 --- a/api/app/core/workflow/nodes/tool/config.py +++ b/api/app/core/workflow/nodes/tool/config.py @@ -1,4 +1,6 @@ from pydantic import Field +from typing import Any + from app.core.workflow.nodes.base_config import BaseNodeConfig @@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig): """工具节点配置""" tool_id: str = Field(..., description="工具ID") - tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") + tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 993a3804..e1b5f380 
100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -1,5 +1,5 @@ import logging -import uuid +import re from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState @@ -9,6 +9,8 @@ from app.db import get_db_read logger = logging.getLogger(__name__) +TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}") + class ToolNode(BaseNode): """工具节点""" @@ -25,25 +27,33 @@ class ToolNode(BaseNode): # 如果没有租户ID,尝试从工作流ID获取 if not tenant_id: - workflow_id = self.get_variable("sys.workflow_id", state) - if workflow_id: + workspace_id = self.get_variable("sys.workspace_id", state) + if workspace_id: from app.repositories.tool_repository import ToolRepository with get_db_read() as db: - tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id) if not tenant_id: - tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097") - # logger.error(f"节点 {self.node_id} 缺少租户ID") - # return {"error": "缺少租户ID"} + logger.error(f"节点 {self.node_id} 缺少租户ID") + return { + "success": False, + "data": "缺少租户ID" + } # 渲染工具参数 rendered_parameters = {} for param_name, param_template in self.typed_config.tool_parameters.items(): - rendered_value = self._render_template(param_template, state) + if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): + try: + rendered_value = self._render_template(param_template, state) + except Exception as e: + raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + else: + # 非模板参数(数字/布尔/普通字符串)直接保留原值 + rendered_value = param_template rendered_parameters[param_name] = rendered_value logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}") - print(self.typed_config.tool_id) # 执行工具 with get_db_read() as db: @@ -54,7 +64,7 @@ class ToolNode(BaseNode): tenant_id=tenant_id, user_id=user_id ) - print(result) + if result.success: 
logger.info(f"节点 {self.node_id} 工具执行成功") return { @@ -66,7 +76,7 @@ class ToolNode(BaseNode): logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") return { "success": False, - "error": result.error, + "data": result.error, "error_code": result.error_code, "execution_time": result.execution_time } \ No newline at end of file diff --git a/api/app/core/workflow/template_renderer.py b/api/app/core/workflow/template_renderer.py index df6053b0..198a3322 100644 --- a/api/app/core/workflow/template_renderer.py +++ b/api/app/core/workflow/template_renderer.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) class TemplateRenderer: """模板渲染器""" - + def __init__(self, strict: bool = True): """初始化渲染器 @@ -25,13 +25,13 @@ class TemplateRenderer: undefined=StrictUndefined if strict else Undefined, autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML ) - + def render( - self, - template: str, - variables: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + self, + template: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> str: """渲染模板 @@ -69,40 +69,40 @@ class TemplateRenderer: # variables 的结构:{"sys": {...}, "conv": {...}} sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} - + context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} + "conv": conv_vars, # 会话变量:{{conv.user_name}} + "node": node_outputs, # 节点输出:{{node.node_1.output}} "sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) } - + # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} # 将所有节点输出添加到顶层上下文 if node_outputs: context.update(node_outputs) - + # 支持直接访问会话变量(不需要 conv. 
前缀):{{user_name}} if conv_vars: context.update(conv_vars) - + context["nodes"] = node_outputs or {} # 旧语法兼容 - + try: tmpl = self.env.from_string(template) return tmpl.render(**context) - + except TemplateSyntaxError as e: logger.error(f"模板语法错误: {template}, 错误: {e}") raise ValueError(f"模板语法错误: {e}") - + except UndefinedError as e: logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") raise ValueError(f"未定义的变量: {e}") - + except Exception as e: logger.error(f"模板渲染异常: {template}, 错误: {e}") raise ValueError(f"模板渲染失败: {e}") - + def validate(self, template: str) -> list[str]: """验证模板语法 @@ -121,14 +121,14 @@ class TemplateRenderer: ['模板语法错误: ...'] """ errors = [] - + try: self.env.from_string(template) except TemplateSyntaxError as e: errors.append(f"模板语法错误: {e}") except Exception as e: errors.append(f"模板验证失败: {e}") - + return errors @@ -137,14 +137,16 @@ _default_renderer = TemplateRenderer(strict=True) def render_template( - template: str, - variables: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + template: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None, + struct: bool = True ) -> str: """渲染模板(便捷函数) Args: + struct: 渲染模式 template: 模板字符串 variables: 用户变量 node_outputs: 节点输出 @@ -162,7 +164,8 @@ def render_template( ... 
) '请分析: 这是一段文本' """ - return _default_renderer.render(template, variables, node_outputs, system_vars) + renderer = TemplateRenderer(strict=struct) + return renderer.render(template, variables, node_outputs, system_vars) def validate_template(template: str) -> list[str]: diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 00358d91..6daf415d 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -87,10 +87,11 @@ class WorkflowValidator: return graphs @classmethod - def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: + def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]: """验证工作流配置 Args: + publish: 发布验证标识 workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型 Returns: @@ -114,7 +115,7 @@ class WorkflowValidator: graphs = cls.get_subgraph(workflow_config) logger.info(graphs) - for graph in graphs: + for index, graph in enumerate(graphs): nodes = graph.get("nodes", []) edges = graph.get("edges", []) variables = graph.get("variables", []) @@ -125,10 +126,11 @@ class WorkflowValidator: elif len(start_nodes) > 1: errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个") - # 2. 验证 end 节点(至少一个) - end_nodes = [n for n in nodes if n.get("type") == NodeType.END] - if len(end_nodes) == 0: - errors.append("工作流必须至少有一个 end 节点") + if index == len(graphs) - 1: + # 2. 验证 主图end 节点(至少一个) + end_nodes = [n for n in nodes if n.get("type") == NodeType.END] + if len(end_nodes) == 0: + errors.append("工作流必须至少有一个 end 节点") # 3. 验证节点 ID 唯一性 node_ids = [n.get("id") for n in nodes] @@ -159,15 +161,17 @@ class WorkflowValidator: elif target not in node_id_set: errors.append(f"边 #{i} 的 target 节点不存在: {target}") - # 6. 
验证所有节点可达(从 start 节点出发) - if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 - reachable = WorkflowValidator._get_reachable_nodes( - start_nodes[0]["id"], - edges - ) - unreachable = node_id_set - reachable - if unreachable: - errors.append(f"以下节点无法从 start 节点到达: {unreachable}") + if publish: + # 仅在发布时验证所有节点可达 + # 6. 验证所有节点可达(从 start 节点出发) + if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 + reachable = WorkflowValidator._get_reachable_nodes( + start_nodes[0]["id"], + edges + ) + unreachable = node_id_set - reachable + if unreachable: + errors.append(f"以下节点无法从 start 节点到达: {unreachable}") # 7. 检测循环依赖(非 loop 节点) if not errors: # 只有在前面验证通过时才检查循环 @@ -288,7 +292,7 @@ class WorkflowValidator: (is_valid, errors): 是否有效和错误列表 """ # 先执行基础验证 - is_valid, errors = WorkflowValidator.validate(workflow_config) + is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True) if not is_valid: return False, errors diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 01dad24e..189876a5 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -6,6 +6,7 @@ from .document_model import Document from .file_model import File from .generic_file_model import GenericFile from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey +from .memory_short_model import ShortTermMemory, LongTermMemory from .knowledgeshare_model import KnowledgeShare from .app_model import App from .agent_app_config_model import AgentConfig @@ -25,6 +26,7 @@ from .tool_model import ( ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus ) +from .memory_perceptual_model import MemoryPerceptualModel __all__ = [ "Tenants", @@ -67,9 +69,12 @@ __all__ = [ "BuiltinToolConfig", "CustomToolConfig", "MCPToolConfig", + "ShortTermMemory", + "LongTermMemory", "ToolExecution", "ToolType", "ToolStatus", "AuthType", - "ExecutionStatus" + "ExecutionStatus", + "MemoryPerceptualModel" ] diff --git 
a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 373de92c..0a7a5935 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -3,7 +3,10 @@ import uuid from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.orm import relationship + +from app.base.type import PydanticType from app.db import Base +from app.schemas import ModelParameters class AgentConfig(Base): @@ -17,14 +20,17 @@ class AgentConfig(Base): # Agent 行为配置 system_prompt = Column(Text, nullable=True, comment="系统提示词") default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID") - + # 结构化配置(直接存储 JSON) - model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)") + # model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)") + model_parameters = Column(PydanticType(ModelParameters), nullable=True, + comment="模型参数配置(temperature、max_tokens等)") + knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置") memory = Column(JSON, nullable=True, comment="记忆配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置") tools = Column(JSON, default=dict, nullable=True, comment="工具配置") - + # 多 Agent 相关字段 agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等") @@ -41,4 +47,4 @@ class AgentConfig(Base): parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents") def __repr__(self): - return f"" \ No newline at end of file + return f"" diff --git a/api/app/models/conversation_model.py b/api/app/models/conversation_model.py index e7f9e8c4..4011247f 100644 --- a/api/app/models/conversation_model.py +++ b/api/app/models/conversation_model.py @@ -3,6 +3,7 @@ """ 
class ConversationDetail(Base):
    """ORM model for the ``conversation_details`` table.

    Each row points at a conversation through ``conversation_id`` and stores
    the descriptive fields declared below (theme, summary, takeaways,
    user questions and an information score).
    """
    __tablename__ = "conversation_details"

    # Primary key; a uuid4 is generated client-side when none is supplied.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
    # Foreign key to conversations.id. No nullable=False and no unique
    # constraint are declared here.
    conversation_id = Column(UUID(as_uuid=True), ForeignKey("conversations.id"))

    # Conversation theme.
    theme = Column(String, comment="会话主题")
    # Conversation summary.
    summary = Column(String, comment="会话摘要")
    # Key takeaways of the conversation, stored as JSON.
    takeaways = Column(JSON, comment="会话要点")
    # User questions, stored as JSON.
    question = Column(JSON, comment="用户问题")
    # Score describing how much information the conversation carries.
    info_score = Column(Integer, comment="会话信息量评分")
class ForgettingCycleHistory(Base):
    """History record of a single forgetting-cycle execution.

    One row is stored per executed cycle so that trends can be analysed and
    visualised later. Two indexes are declared: one on
    ``(end_user_id, execution_time)`` and one on ``execution_time`` alone.
    """

    __tablename__ = "forgetting_cycle_history"

    # Primary key; a uuid4 is generated client-side when none is supplied.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True, comment="主键ID")
    # End user the cycle was executed for.
    end_user_id = Column(String(255), nullable=False, comment="终端用户ID")
    # When the cycle ran; defaults to the (naive) local time at insert.
    execution_time = Column(DateTime, nullable=False, default=datetime.now, comment="执行时间")
    # Number of node pairs merged successfully in this cycle.
    merged_count = Column(Integer, default=0, comment="本次成功融合的节点对数")
    # Number of node pairs whose merge failed in this cycle.
    failed_count = Column(Integer, default=0, comment="本次融合失败的节点对数")
    # Average activation value; may be NULL.
    average_activation_value = Column(Float, nullable=True, comment="平均激活值")
    # Total number of nodes.
    total_nodes = Column(Integer, default=0, comment="总节点数")
    # Nodes below the forgetting threshold (merged, failed and pending).
    low_activation_nodes = Column(Integer, default=0, comment="低于遗忘阈值的节点总数(包含已融合、失败和待处理的)")
    # Execution duration in seconds; may be NULL.
    duration_seconds = Column(Float, nullable=True, comment="执行耗时(秒)")
    # How the cycle was triggered: manual or scheduled.
    trigger_type = Column(String(50), default="manual", comment="触发类型: manual/scheduled")

    # Indexes that speed up per-user and time-ordered lookups.
    __table_args__ = (
        Index('idx_end_user_time', 'end_user_id', 'execution_time'),
        Index('idx_execution_time', 'execution_time'),
    )

    def __repr__(self):
        """Return a debugging representation naming the key fields."""
        # The previous implementation returned an empty string, which made
        # instances invisible in logs and debugger output.
        return (
            f"<ForgettingCycleHistory(id={self.id}, end_user_id={self.end_user_id!r}, "
            f"execution_time={self.execution_time}, merged_count={self.merged_count}, "
            f"failed_count={self.failed_count}, trigger_type={self.trigger_type!r})>"
        )
class ShortTermMemory(Base):
    """Short-term memory table.

    Stores temporary conversation memories that are usually kept only for a
    short time.
    """
    __tablename__ = "memory_short_term"

    # Primary key; a uuid4 is generated client-side when none is supplied.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")

    # End user the memory belongs to.
    end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")

    # Conversation content: the user message and the optional AI reply.
    messages = Column(Text, nullable=False, comment="用户消息内容")
    aimessages = Column(Text, nullable=True, comment="AI回复消息内容")

    # Search switch state.
    search_switch = Column(String(50), nullable=True, comment="搜索开关状态")

    # Retrieved content, stored as a JSON list of dicts: [{}, {}].
    retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")

    # Creation timestamp; defaults to the (naive) local time at insert.
    created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")

    def __repr__(self):
        """Return a debugging representation naming the key fields."""
        # The previous implementation returned an empty string. Message text
        # is deliberately left out to keep the representation short.
        return (
            f"<ShortTermMemory(id={self.id}, end_user_id={self.end_user_id!r}, "
            f"created_at={self.created_at})>"
        )
class OrchestrationMode(StrEnum):
    """Orchestration (collaboration) modes available to a multi-agent setup."""
    COLLABORATION = "collaboration"  # collaboration mode: agents can hand off to one another
    SUPERVISOR = "supervisor"  # supervisor mode: the master agent dispatches the sub-agents
+# return self.model.model_validate(value) # pydantic v2 class MultiAgentConfig(Base): """多 Agent 配置表""" @@ -66,8 +66,8 @@ class MultiAgentConfig(Base): orchestration_mode = Column( String(20), nullable=False, - default="conditional", - comment="协作模式: sequential|parallel|conditional|loop" + default="collaboration", + comment="协作模式: collaboration(协作)| supervisor(监督)" ) # 子 Agent 列表 diff --git a/api/app/repositories/conversation_repository.py b/api/app/repositories/conversation_repository.py new file mode 100644 index 00000000..eb5d3c61 --- /dev/null +++ b/api/app/repositories/conversation_repository.py @@ -0,0 +1,317 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, desc, func +from sqlalchemy.orm import Session + +from app.core.exceptions import ResourceNotFoundException +from app.core.logging_config import get_db_logger +from app.models import Conversation, Message +from app.models.conversation_model import ConversationDetail + +logger = get_db_logger() + + +class ConversationRepository: + """Repository for Conversation entity, encapsulating CRUD operations.""" + + def __init__(self, db: Session): + self.db = db + + def create_conversation( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + user_id: Optional[str] = None, + title: Optional[str] = None, + is_draft: bool = False, + config_snapshot: Optional[dict] = None + ) -> Conversation: + """ + Create a new conversation record. + + Args: + app_id: Application ID the conversation belongs to. + workspace_id: Workspace ID where the conversation is created. + user_id: Optional user ID associated with the conversation. + title: Optional conversation title. Defaults to "New Conversation". + is_draft: Whether the conversation is a draft. + config_snapshot: Optional configuration snapshot. + + Returns: + Conversation: Newly created Conversation instance. 
def get_conversation_by_conversation_id(
        self,
        conversation_id: uuid.UUID,
        workspace_id: Optional[uuid.UUID] = None
) -> Conversation:
    """
    Look up a single conversation by primary key.

    Args:
        conversation_id: UUID of the conversation to load.
        workspace_id: When given, the conversation must also belong to
            this workspace.

    Raises:
        ResourceNotFoundException: No conversation matches the criteria.

    Returns:
        Conversation: The first matching Conversation instance.
    """
    logger.info(f"Fetching conversation: {conversation_id}")

    # Collect every criterion first, then build the statement in one go.
    criteria = [Conversation.id == conversation_id]
    if workspace_id:
        criteria.append(Conversation.workspace_id == workspace_id)

    record = self.db.scalars(select(Conversation).where(*criteria)).first()
    if record:
        logger.info(f"Conversation fetched successfully: {conversation_id}")
        return record

    logger.warning(f"Conversation not found: {conversation_id}")
    raise ResourceNotFoundException("Conversation", str(conversation_id))
def list_conversations(
        self,
        app_id: uuid.UUID,
        workspace_id: uuid.UUID,
        user_id: Optional[str] = None,
        is_draft: Optional[bool] = None,
        page: int = 1,
        pagesize: int = 20
) -> tuple[list[Conversation], int]:
    """
    List active conversations with optional filters and pagination.

    Args:
        app_id: Application ID filter.
        workspace_id: Workspace ID filter.
        user_id: Optional user ID filter (compared as a string).
        is_draft: Optional draft status filter; ``None`` means no filter.
        page: Page number (1-based). Values below 1 are treated as 1.
        pagesize: Number of items per page. Negative values are treated
            as 0.

    Returns:
        Tuple[List[Conversation], int]: The conversations of the requested
        page, ordered by ``updated_at`` descending, and the total number of
        rows matching the filters.
    """
    # A page below 1 or a negative page size would produce a negative
    # OFFSET/LIMIT, which the database rejects; clamp instead of failing.
    page = max(page, 1)
    pagesize = max(pagesize, 0)

    stmt = select(Conversation).where(
        Conversation.app_id == app_id,
        Conversation.workspace_id == workspace_id,
        Conversation.is_active.is_(True)
    )

    if user_id:
        stmt = stmt.where(Conversation.user_id == str(user_id))

    if is_draft is not None:
        stmt = stmt.where(Conversation.is_draft == is_draft)

    # Count the filtered rows before ordering and pagination are applied.
    total = int(self.db.execute(
        select(func.count()).select_from(stmt.subquery())
    ).scalar_one())

    # Apply ordering and pagination.
    stmt = stmt.order_by(desc(Conversation.updated_at))
    stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)

    conversations = list(self.db.scalars(stmt).all())

    logger.info(
        "Listed conversations successfully",
        extra={
            "app_id": str(app_id),
            "workspace_id": str(workspace_id),
            "returned": len(conversations),
            "total": total
        }
    )
    return conversations, total
def get_message_by_conversation_id(
        self,
        conversation_id: uuid.UUID,
        limit: Optional[int] = None
) -> list[Message]:
    """
    Retrieve the messages of a conversation, ordered by creation time.

    Args:
        conversation_id: The UUID of the conversation.
        limit: Optional maximum number of messages to return. ``None``
            means no limit; ``0`` returns an empty list.

    Returns:
        List[Message]: Message instances in ascending ``created_at`` order.
    """
    stmt = select(Message).where(
        Message.conversation_id == conversation_id
    ).order_by(Message.created_at)

    # Compare against None explicitly: a truthiness test would treat
    # limit=0 as "no limit" and return every message.
    if limit is not None:
        stmt = stmt.limit(limit)

    messages = list(self.db.scalars(stmt).all())

    logger.info(
        "Fetched messages successfully",
        extra={
            "conversation_id": str(conversation_id),
            "returned": len(messages)
        }
    )
    return messages
def get_latest_by_end_user(
    self,
    db: Session,
    end_user_id: str
) -> Optional[ForgettingCycleHistory]:
    """
    Fetch the most recent history record of an end user.

    Args:
        db: Database session.
        end_user_id: End user ID.

    Returns:
        Optional[ForgettingCycleHistory]: The record with the greatest
        execution time, or None when the user has no records.
    """
    newest_first = desc(ForgettingCycleHistory.execution_time)
    user_records = db.query(ForgettingCycleHistory).filter(
        ForgettingCycleHistory.end_user_id == end_user_id
    )
    return user_records.order_by(newest_first).first()
def create_perceptual_memory(
        self,
        end_user_id: uuid.UUID,
        perceptual_type: PerceptualType,
        file_path: str,
        file_name: str,
        file_ext: str,
        summary: Optional[str] = None,
        meta_data: Optional[dict] = None,
        storage_service: FileStorageType = FileStorageType.LOCAL
) -> MemoryPerceptualModel:
    """Insert a perceptual memory row and flush it to the database.

    The row is added to the session and flushed, but not committed.
    Any exception is logged and re-raised unchanged.
    """
    db_logger.debug(f"Creating perceptual memory: end_user_id={end_user_id}, "
                    f"type={perceptual_type}, file={file_name}")

    try:
        # Gather the column values first, then build the ORM object from them.
        columns = dict(
            end_user_id=end_user_id,
            perceptual_type=perceptual_type,
            storage_service=storage_service,
            file_path=file_path,
            file_name=file_name,
            file_ext=file_ext,
            summary=summary,
            meta_data=meta_data,
            created_time=datetime.now(),
        )
        record = MemoryPerceptualModel(**columns)

        self.db.add(record)
        self.db.flush()

        db_logger.info(f"Perceptual memory created successfully: id={record.id}, file={file_name}")
        return record
    except Exception as exc:
        db_logger.error(f"Failed to create perceptual memory: end_user_id={end_user_id} - {str(exc)}")
        raise
def get_timeline(
        self,
        end_user_id: uuid.UUID,
        query: PerceptualQuerySchema
) -> Tuple[int, List[MemoryPerceptualModel]]:
    """Return one page of a user's perceptual memories, newest first.

    The result is ``(total, rows)``: ``total`` counts every row matching
    the user and the optional type filter, and ``rows`` is the page
    selected by ``query.page`` and ``query.page_size``. Any exception is
    logged and re-raised unchanged.
    """
    db_logger.debug(f"Querying perceptual memory timeline: end_user_id={end_user_id}, filter={query.filter}")

    try:
        conditions = [MemoryPerceptualModel.end_user_id == end_user_id]
        wanted_type = query.filter.type
        if wanted_type is not None:
            conditions.append(MemoryPerceptualModel.perceptual_type == wanted_type)

        matching = self.db.query(MemoryPerceptualModel).filter(*conditions)
        total_count = matching.count()

        # Page arithmetic is evaluated after the count, as in the original.
        skip = (query.page - 1) * query.page_size
        newest_first = matching.order_by(desc(MemoryPerceptualModel.created_time))
        memories = newest_first.offset(skip).limit(query.page_size).all()

        db_logger.info(
            f"Perceptual memory timeline query succeeded: end_user_id={end_user_id}, total={total_count}, returned={len(memories)}")
        return total_count, memories
    except Exception as exc:
        db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(exc)}")
        raise
def count_by_user_id(self, end_user_id: str) -> int:
    """Count the short-term memory records of one end user.

    Args:
        end_user_id: End user ID whose records are counted.

    Returns:
        int: Number of short-term memory rows stored for the user.

    Raises:
        Exception: Any error raised by the query is re-raised after the
            session has been rolled back and the failure logged.
    """
    try:
        count = (
            self.db.query(ShortTermMemory)
            .filter(ShortTermMemory.end_user_id == end_user_id)
            .count()
        )
        db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}")

        return count

    except Exception as e:
        self.db.rollback()
        # `count` is unbound when the query itself fails, so the message
        # must not reference it: doing so raised UnboundLocalError and hid
        # the real error. Report the user ID instead.
        db_logger.error(f"统计用户 {end_user_id} 的短期记忆数量时出错: {str(e)}")
        raise
def delete_old_memories(self, days: int = 7) -> int:
    """Delete short-term memory records older than the retention window.

    Args:
        days: Number of days to keep; rows created before
            ``now - days`` are deleted. Defaults to 7.

    Returns:
        int: Number of rows deleted.
    """
    try:
        threshold = datetime.datetime.now() - datetime.timedelta(days=days)

        expired = self.db.query(ShortTermMemory).filter(
            ShortTermMemory.created_at < threshold
        )
        # Bulk delete; the session's in-memory state is not synchronised.
        removed = expired.delete(synchronize_session=False)
        self.db.commit()

        db_logger.info(f"成功删除 {days} 天前的 {removed} 条短期记忆记录")
        return removed

    except Exception as exc:
        self.db.rollback()
        db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(exc)}")
        raise
existing_memory.messages = messages + existing_memory.aimessages = aimessages + existing_memory.search_switch = search_switch + existing_memory.retrieved_content = retrieved_content or [] + + self.db.commit() + self.db.refresh(existing_memory) + + db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}") + return existing_memory + else: + # 创建新记录 + new_memory = ShortTermMemory( + end_user_id=end_user_id, + messages=messages, + aimessages=aimessages, + search_switch=search_switch, + retrieved_content=retrieved_content or [] + ) + + self.db.add(new_memory) + self.db.commit() + self.db.refresh(new_memory) + + db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}") + return new_memory + + except Exception as e: + self.db.rollback() + db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}") + raise + + +class LongTermMemoryRepository: + """长期记忆仓储类""" + + def __init__(self, db: Session): + self.db = db + + def create(self, end_user_id: str, retrieved_content: List[Dict] = None) -> LongTermMemory: + """创建长期记忆记录 + + Args: + end_user_id: 终端用户ID + retrieved_content: 检索到的相关内容,格式为[{}, {}] + + Returns: + LongTermMemory: 创建的长期记忆对象 + """ + try: + memory = LongTermMemory( + end_user_id=end_user_id, + retrieved_content=retrieved_content or [] + ) + + self.db.add(memory) + self.db.commit() + self.db.refresh(memory) + + db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}") + return memory + + except Exception as e: + self.db.rollback() + db_logger.error(f"创建长期记忆记录时出错: {str(e)}") + raise + + def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]: + """根据ID获取长期记忆记录 + + Args: + memory_id: 记忆ID + + Returns: + Optional[LongTermMemory]: 记忆对象,如果不存在则返回None + """ + try: + memory = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.id == memory_id) + .first() + ) + + if memory: + db_logger.debug(f"成功查询到长期记忆记录 {memory_id}") + else: + db_logger.debug(f"未找到长期记忆记录 {memory_id}") + + return memory + + except Exception as e: + self.db.rollback() 
+ db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}") + raise + + def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]: + """根据用户ID获取长期记忆记录列表 + + Args: + end_user_id: 终端用户ID + limit: 返回记录数限制,默认100 + offset: 偏移量,默认0 + + Returns: + List[LongTermMemory]: 记忆记录列表,按创建时间倒序 + """ + try: + # 使用复合索引 ix_memory_long_term_user_time 优化查询 + memories = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.end_user_id == end_user_id) + .order_by(LongTermMemory.created_at.desc()) + .limit(limit) + .offset(offset) + .all() + ) + + db_logger.info(f"成功查询用户 {end_user_id} 的 {len(memories)} 条长期记忆记录") + return memories + + except Exception as e: + self.db.rollback() + db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}") + raise + + def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]: + """根据内容关键词搜索长期记忆记录 + + Args: + end_user_id: 终端用户ID + keyword: 搜索关键词 + limit: 返回记录数限制,默认50 + + Returns: + List[LongTermMemory]: 匹配的记忆记录列表,按创建时间倒序 + """ + try: + # 使用 GIN 索引 ix_memory_long_term_retrieved_content_gin 优化 JSON 搜索 + # 同时使用复合索引 ix_memory_long_term_user_time 优化用户过滤 + memories = ( + self.db.query(LongTermMemory) + .filter( + LongTermMemory.end_user_id == end_user_id, + LongTermMemory.retrieved_content.astext.contains(keyword) + ) + .order_by(LongTermMemory.created_at.desc()) + .limit(limit) + .all() + ) + + db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}' 的 {len(memories)} 条长期记忆记录") + return memories + + except Exception as e: + self.db.rollback() + db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}") + raise + + def delete_by_id(self, memory_id: uuid.UUID) -> bool: + """删除指定ID的长期记忆记录 + + Args: + memory_id: 记忆ID + + Returns: + bool: 删除成功返回True,否则返回False + """ + try: + deleted_count = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.id == memory_id) + .delete(synchronize_session=False) + ) + + self.db.commit() + + if deleted_count > 0: + 
db_logger.info(f"成功删除长期记忆记录 {memory_id}") + return True + else: + db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除") + return False + + except Exception as e: + self.db.rollback() + db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}") + raise + + def count_by_user_id(self, end_user_id: str) -> int: + """统计用户的长期记忆记录数量 + + Args: + end_user_id: 终端用户ID + + Returns: + int: 记录数量 + """ + try: + count = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.end_user_id == end_user_id) + .count() + ) + + db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录") + return count + + except Exception as e: + self.db.rollback() + db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}") + raise + + def upsert(self, end_user_id: str, retrieved_content: List[Dict] = None) -> Optional[LongTermMemory]: + """创建或更新长期记忆记录 + + 根据 end_user_id 和 retrieved_content 判断是否需要写入: + - 如果找到相同的 end_user_id 和 retrieved_content,则不写入,返回 None + - 如果没有找到相同的记录,则创建新记录 + + Args: + end_user_id: 终端用户ID + retrieved_content: 检索到的相关内容,格式为[{}, {}] + + Returns: + Optional[LongTermMemory]: 创建的长期记忆对象,如果不需要写入则返回 None + """ + try: + retrieved_content = retrieved_content or [] + + # 优化查询:使用复合索引 ix_memory_long_term_user_time 先过滤用户 + # 然后在应用层比较 JSON 内容,避免复杂的数据库 JSON 比较 + existing_memories = ( + self.db.query(LongTermMemory) + .filter(LongTermMemory.end_user_id == end_user_id) + .order_by(LongTermMemory.created_at.desc()) + .limit(100) # 限制查询数量,避免加载过多数据 + .all() + ) + + # 在 Python 中比较 retrieved_content + for memory in existing_memories: + if memory.retrieved_content == retrieved_content: + # 如果找到相同的记录,不写入 + db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}") + return None + + # 如果没有找到相同的记录,创建新记录 + new_memory = LongTermMemory( + end_user_id=end_user_id, + retrieved_content=retrieved_content + ) + + self.db.add(new_memory) + self.db.commit() + self.db.refresh(new_memory) + + db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}") + return new_memory + + except Exception as e: + self.db.rollback() + 
db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}") + raise + + diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 79466fa0..1e24eeae 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -211,6 +211,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector "dialog_id": s.dialog_id, "chunk_ids": s.chunk_ids, "content": s.content, + "memory_type": s.memory_type, # 添加 memory_type 字段 "summary_embedding": s.summary_embedding if s.summary_embedding else None, "config_id": s.config_id, # 添加 config_id }) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 259b1325..7d77ad4f 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -92,6 +92,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity WHEN entity.description IS NOT NULL AND entity.description <> '' AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description)) THEN entity.description ELSE e.description END, + e.example = CASE + WHEN entity.example IS NOT NULL AND entity.example <> '' + THEN entity.example + ELSE coalesce(e.example, '') + END, e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END, e.aliases = CASE WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0 @@ -121,7 +126,8 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity e.activation_value = CASE WHEN entity.activation_value IS NOT NULL THEN entity.activation_value ELSE e.activation_value END, e.access_history = CASE WHEN entity.access_history IS NOT NULL THEN entity.access_history ELSE coalesce(e.access_history, []) END, e.last_access_time = CASE WHEN entity.last_access_time IS NOT NULL THEN entity.last_access_time ELSE 
e.last_access_time END, - e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END + e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END, + e.is_explicit_memory = CASE WHEN entity.is_explicit_memory IS NOT NULL THEN entity.is_explicit_memory ELSE coalesce(e.is_explicit_memory, false) END RETURN e.id AS uuid """ @@ -722,7 +728,12 @@ SET m += { chunk_ids: summary.chunk_ids, content: summary.content, summary_embedding: summary.summary_embedding, - config_id: summary.config_id + config_id: summary.config_id, + importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END, + activation_value: CASE WHEN summary.activation_value IS NOT NULL THEN summary.activation_value ELSE m.activation_value END, + access_history: CASE WHEN summary.access_history IS NOT NULL THEN summary.access_history ELSE coalesce(m.access_history, []) END, + last_access_time: CASE WHEN summary.last_access_time IS NOT NULL THEN summary.last_access_time ELSE m.last_access_time END, + access_count: CASE WHEN summary.access_count IS NOT NULL THEN summary.access_count ELSE coalesce(m.access_count, 0) END } RETURN m.id AS uuid """ @@ -857,3 +868,174 @@ neo4j_query_all = """ """ +'''针对当前节点下扩长的句子,实体和总结''' +Memory_Timeline_ExtractedEntity=""" +MATCH (n)-[r1]-(e)-[r2]-(ms) +WHERE elementId(n) = $id + AND (ms:ExtractedEntity OR ms:MemorySummary) + +RETURN + collect( + DISTINCT + CASE + WHEN ms:ExtractedEntity THEN { + text: ms.name, + created_at: ms.created_at, + type: "情景记忆" + } + END + ) AS ExtractedEntity, + + collect( + DISTINCT + CASE + WHEN ms:MemorySummary THEN { + text: ms.content, + created_at: ms.created_at, + type: "长期沉淀" + } + END + ) AS MemorySummary, + + collect( + DISTINCT { + text: e.statement, + created_at: e.created_at, + type: "情绪记忆" + } + ) AS statement; + + +""" 
+Memory_Timeline_MemorySummary=""" +MATCH (n)-[r1]-(e)-[r2]-(ms) +WHERE elementId(n) =$id + AND (ms:MemorySummary OR ms:ExtractedEntity) +RETURN + collect( + DISTINCT + CASE + WHEN ms:ExtractedEntity THEN { + text: ms.name, + created_at: ms.created_at + } + END + ) AS ExtractedEntity, + + collect( + DISTINCT + CASE + WHEN n:MemorySummary THEN { + text: n.content, + created_at: n.created_at + } + END + ) AS MemorySummary, + + collect( + DISTINCT { + text: e.statement, + created_at: e.created_at + } + ) AS statement; +""" +Memory_Timeline_Statement=""" +MATCH (n) +WHERE elementId(n) = $id + +CALL { + WITH n + MATCH (n)-[]-(m:ExtractedEntity) + WHERE NOT m:MemorySummary AND NOT m:Chunk + RETURN collect( + DISTINCT { + text: m.name, + created_at: m.created_at, + type: "情景记忆" + } + ) AS ExtractedEntity +} + +CALL { + WITH n + MATCH (n)-[]-(m:MemorySummary) + WHERE NOT m:Chunk + RETURN collect( + DISTINCT { + text: m.content, + created_at: m.created_at, + type: "长期沉淀" + } + ) AS MemorySummary +} + +RETURN + ExtractedEntity, + MemorySummary, + { + text: n.statement, + created_at: n.created_at, + type: "情绪记忆" + } AS statement; + + +""" + +'''针对当前节点,主要获取更加完整的句子节点''' +Memory_Space_Emotion_Statement=""" +MATCH (n) +WHERE elementId(n) = $id +RETURN + n.emotion_intensity AS emotion_intensity, + n.created_at AS created_at, + n.emotion_type AS emotion_type, + n.statement AS statement; + +""" +Memory_Space_Emotion_MemorySummary=""" +MATCH (n)-[]-(e) +WHERE elementId(n) = $id + AND EXISTS { + MATCH (e)-[]-(ms) + WHERE ms:MemorySummary OR ms:ExtractedEntity + } +RETURN DISTINCT + e.emotion_intensity AS emotion_intensity, + e.created_at AS created_at, + e.emotion_type AS emotion_type, + e.statement AS statement; +""" +Memory_Space_Emotion_ExtractedEntity=""" +MATCH (n)-[]-(e) +WHERE elementId(n) = $id + AND EXISTS { + MATCH (e)-[]-(ms:ExtractedEntity) + } +RETURN DISTINCT + e.emotion_intensity AS emotion_intensity, + e.created_at AS created_at, + e.emotion_type AS emotion_type, + 
e.statement AS statement; +""" + +'''获取实体''' + +Memory_Space_User=""" +MATCH (n)-[r]->(m) +WHERE n.group_id = $group_id AND m.name="用户" +return DISTINCT elementId(m) as id +""" +Memory_Space_Entity=""" +MATCH (n)-[]-(m) +WHERE elementId(m) = $id AND m.entity_type = "Person" +RETURN +DISTINCT m.name as name,m.group_id as group_id +""" +Memory_Space_Associative=""" +MATCH (u)-[]-(x)-[]-(h) +WHERE elementId(u) = $user_id + AND elementId(h) = $id +RETURN DISTINCT + x.statement as statement,x.created_at as created_at +""" + diff --git a/api/app/repositories/neo4j/entity_repository.py b/api/app/repositories/neo4j/entity_repository.py index cb18feca..f4ca35c8 100644 --- a/api/app/repositories/neo4j/entity_repository.py +++ b/api/app/repositories/neo4j/entity_repository.py @@ -58,7 +58,7 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]): # 处理 ACT-R 属性 - 确保字段存在且有默认值 n['importance_score'] = n.get('importance_score', 0.5) n['activation_value'] = n.get('activation_value') - n['access_history'] = n.get('access_history', []) + n['access_history'] = n.get('access_history') or [] n['last_access_time'] = n.get('last_access_time') n['access_count'] = n.get('access_count', 0) diff --git a/api/app/repositories/neo4j/memory_summary_repository.py b/api/app/repositories/neo4j/memory_summary_repository.py new file mode 100644 index 00000000..fc743f33 --- /dev/null +++ b/api/app/repositories/neo4j/memory_summary_repository.py @@ -0,0 +1,273 @@ +# -*- coding: utf-8 -*- +"""Memory Summary Repository Module + +This module provides data access functionality for MemorySummary nodes. 
+ +Classes: + MemorySummaryRepository: Repository for managing MemorySummary CRUD operations +""" + +from datetime import datetime +from typing import Any, Dict, List, Optional + +from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + +class MemorySummaryRepository(BaseNeo4jRepository): + """Memory Summary Repository + + Manages CRUD operations for MemorySummary nodes. + Provides methods to query summaries by group_id, user_id, and time ranges. + + Attributes: + connector: Neo4j connector instance + node_label: Node label, fixed as "MemorySummary" + """ + + def __init__(self, connector: Neo4jConnector): + """Initialize memory summary repository + + Args: + connector: Neo4j connector instance + """ + super().__init__(connector, "MemorySummary") + + def _map_to_dict(self, node_data: Dict) -> Dict[str, Any]: + """Map node data to dictionary format + + Args: + node_data: Node data returned from Neo4j query + + Returns: + Dict[str, Any]: Memory summary data dictionary + """ + # Extract node data from query result + n = node_data.get('n', node_data) + + # Handle datetime fields + if isinstance(n.get('created_at'), str): + n['created_at'] = datetime.fromisoformat(n['created_at']) + + return dict(n) + + async def find_by_group_id( + self, + group_id: str, + limit: int = 1000, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None + ) -> List[Dict[str, Any]]: + """Query memory summaries by group_id + + Args: + group_id: Group ID to filter by + limit: Maximum number of results to return + start_date: Optional start date filter + end_date: Optional end date filter + + Returns: + List[Dict[str, Any]]: List of memory summary dictionaries + """ + query = f""" + MATCH (n:{self.node_label}) + WHERE n.group_id = $group_id + """ + + params = {"group_id": group_id, "limit": limit} + + # Add date range filters if provided + if start_date: + query += " AND n.created_at 
>= $start_date" + params["start_date"] = start_date + + if end_date: + query += " AND n.created_at <= $end_date" + params["end_date"] = end_date + + query += """ + RETURN n + ORDER BY n.created_at DESC + LIMIT $limit + """ + + results = await self.connector.execute_query(query, **params) + return [self._map_to_dict(r) for r in results] + + async def find_by_user_id( + self, + user_id: str, + limit: int = 1000, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None + ) -> List[Dict[str, Any]]: + """Query memory summaries by user_id + + Args: + user_id: User ID to filter by + limit: Maximum number of results to return + start_date: Optional start date filter + end_date: Optional end date filter + + Returns: + List[Dict[str, Any]]: List of memory summary dictionaries + """ + query = f""" + MATCH (n:{self.node_label}) + WHERE n.user_id = $user_id + """ + + params = {"user_id": user_id, "limit": limit} + + # Add date range filters if provided + if start_date: + query += " AND n.created_at >= $start_date" + params["start_date"] = start_date + + if end_date: + query += " AND n.created_at <= $end_date" + params["end_date"] = end_date + + query += """ + RETURN n + ORDER BY n.created_at DESC + LIMIT $limit + """ + + results = await self.connector.execute_query(query, **params) + return [self._map_to_dict(r) for r in results] + + async def find_by_group_and_user( + self, + group_id: str, + user_id: str, + limit: int = 1000, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None + ) -> List[Dict[str, Any]]: + """Query memory summaries by both group_id and user_id + + Args: + group_id: Group ID to filter by + user_id: User ID to filter by + limit: Maximum number of results to return + start_date: Optional start date filter + end_date: Optional end date filter + + Returns: + List[Dict[str, Any]]: List of memory summary dictionaries + """ + query = f""" + MATCH (n:{self.node_label}) + WHERE n.group_id = $group_id AND n.user_id = 
$user_id + """ + + params = {"group_id": group_id, "user_id": user_id, "limit": limit} + + # Add date range filters if provided + if start_date: + query += " AND n.created_at >= $start_date" + params["start_date"] = start_date + + if end_date: + query += " AND n.created_at <= $end_date" + params["end_date"] = end_date + + query += """ + RETURN n + ORDER BY n.created_at DESC + LIMIT $limit + """ + + results = await self.connector.execute_query(query, **params) + return [self._map_to_dict(r) for r in results] + + async def find_recent_summaries( + self, + group_id: str, + days: int = 7, + limit: int = 1000 + ) -> List[Dict[str, Any]]: + """Query recent memory summaries + + Args: + group_id: Group ID to filter by + days: Number of recent days to query + limit: Maximum number of results to return + + Returns: + List[Dict[str, Any]]: List of memory summary dictionaries + """ + query = f""" + MATCH (n:{self.node_label}) + WHERE n.group_id = $group_id + AND n.created_at >= datetime() - duration({{days: $days}}) + RETURN n + ORDER BY n.created_at DESC + LIMIT $limit + """ + + results = await self.connector.execute_query( + query, + group_id=group_id, + days=days, + limit=limit + ) + return [self._map_to_dict(r) for r in results] + + async def find_by_content_keywords( + self, + group_id: str, + keywords: List[str], + limit: int = 100 + ) -> List[Dict[str, Any]]: + """Query memory summaries by content keywords + + Args: + group_id: Group ID to filter by + keywords: List of keywords to search for in content + limit: Maximum number of results to return + + Returns: + List[Dict[str, Any]]: List of memory summary dictionaries + """ + # Build keyword search conditions + keyword_conditions = [] + params = {"group_id": group_id, "limit": limit} + + for i, keyword in enumerate(keywords): + keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})") + params[f"keyword_{i}"] = keyword + + keyword_filter = " OR ".join(keyword_conditions) + + query = f""" + MATCH 
(n:{self.node_label}) + WHERE n.group_id = $group_id + AND ({keyword_filter}) + RETURN n + ORDER BY n.created_at DESC + LIMIT $limit + """ + + results = await self.connector.execute_query(query, **params) + return [self._map_to_dict(r) for r in results] + + async def get_summary_count_by_group(self, group_id: str) -> int: + """Get count of memory summaries for a group + + Args: + group_id: Group ID to count summaries for + + Returns: + int: Number of memory summaries + """ + query = f""" + MATCH (n:{self.node_label}) + WHERE n.group_id = $group_id + RETURN count(n) as count + """ + + results = await self.connector.execute_query(query, group_id=group_id) + return results[0]['count'] if results else 0 + \ No newline at end of file diff --git a/api/app/repositories/neo4j/statement_repository.py b/api/app/repositories/neo4j/statement_repository.py index 22343e10..cd9f2fac 100644 --- a/api/app/repositories/neo4j/statement_repository.py +++ b/api/app/repositories/neo4j/statement_repository.py @@ -78,7 +78,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]): # 处理 ACT-R 属性 - 确保字段存在且有默认值 n['importance_score'] = n.get('importance_score', 0.5) n['activation_value'] = n.get('activation_value') - n['access_history'] = n.get('access_history', []) + n['access_history'] = n.get('access_history') or [] n['last_access_time'] = n.get('last_access_time') n['access_count'] = n.get('access_count', 0) diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index 3aa7b16e..257910c3 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -38,6 +38,33 @@ class ToolRepository: return result[0] if result else None + @staticmethod + def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]: + """ + 根据空间ID获取tenant_id + + Args: + db: 数据库会话 + workspace_id: 空间ID + + Returns: + tenant_id或None + """ + from app.models.workspace_model import Workspace + + tenant_id = 
db.query(Workspace.tenant_id).filter( + Workspace.id == workspace_id + ).scalar() + + if tenant_id is not None and not isinstance(tenant_id, uuid.UUID): + # 兼容数据库中字段类型不匹配的情况(比如存储为字符串) + try: + tenant_id = uuid.UUID(tenant_id) + except (ValueError, TypeError): + return None + + return tenant_id + @staticmethod def find_by_tenant( db: Session, diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 81cd704d..3c00e5a0 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,6 +1,6 @@ import datetime import uuid -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Union from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator @@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel): class ToolConfig(BaseModel): + """工具配置""" + enabled: bool = Field(default=False, description="是否启用该工具") + tool_id: Optional[str] = Field(default=None, description="工具ID") + operation: Optional[str] = Field(default=None, description="工具特定配置") + +class ToolOldConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置") @@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel): ) # 工具配置 - tools: Dict[str, ToolConfig] = Field( - default_factory=dict, - description="工具配置,key 为工具名称(web_search, code_interpreter, image_generation 等)" + tools: List[ToolConfig] = Field( + default_factory=list, + description="Agent 可用的工具列表" ) @@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel): variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") # 工具配置 - tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") + tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表") # ---------- Output Schemas ---------- @@ -216,7 +222,7 @@ class AgentConfig(BaseModel): variables: List[VariableDefinition] = [] # 
工具配置 - tools: Dict[str, ToolConfig] = {} + tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = [] is_active: bool created_at: datetime.datetime diff --git a/api/app/schemas/conversation_schema.py b/api/app/schemas/conversation_schema.py index 63db6685..6ec9b9b6 100644 --- a/api/app/schemas/conversation_schema.py +++ b/api/app/schemas/conversation_schema.py @@ -35,14 +35,14 @@ class ChatRequest(BaseModel): class Message(BaseModel): """消息输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID conversation_id: uuid.UUID role: str content: str meta_data: Optional[Dict[str, Any]] = None created_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -51,7 +51,7 @@ class Message(BaseModel): class Conversation(BaseModel): """会话输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID app_id: uuid.UUID workspace_id: uuid.UUID @@ -63,11 +63,11 @@ class Conversation(BaseModel): is_active: bool created_at: datetime.datetime updated_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -84,3 +84,12 @@ class ChatResponse(BaseModel): message: str usage: Optional[Dict[str, Any]] = None elapsed_time: Optional[float] = None + + +# ---------- Conversation Summary Schemas ---------- +class ConversationOut(BaseModel): + theme: str + question: list[str] + summary: str + takeaways: list[str] + info_score: int diff --git a/api/app/schemas/implicit_memory_schema.py b/api/app/schemas/implicit_memory_schema.py new file mode 100644 index 00000000..e1770b18 --- /dev/null +++ b/api/app/schemas/implicit_memory_schema.py @@ -0,0 +1,264 @@ 
+"""Implicit Memory Schemas + +This module defines the Pydantic schemas for the implicit memory system API. +""" + +import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator + + +class FrequencyPattern(str, Enum): + """Frequency patterns for habits.""" + DAILY = "daily" + WEEKLY = "weekly" + MONTHLY = "monthly" + SEASONAL = "seasonal" + OCCASIONAL = "occasional" + EVENT_TRIGGERED = "event_triggered" + + +# Request Schemas + +class TimeRange(BaseModel): + """Time range for analysis.""" + start_date: datetime.datetime + end_date: datetime.datetime + + @field_validator('end_date') + @classmethod + def end_date_must_be_after_start_date(cls, v, info): + if 'start_date' in info.data and v <= info.data['start_date']: + raise ValueError('end_date must be after start_date') + return v + + @field_serializer("start_date", when_used="json") + def _serialize_start_date(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("end_date", when_used="json") + def _serialize_end_date(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class DateRange(BaseModel): + """Date range for filtering.""" + start_date: Optional[datetime.datetime] = None + end_date: Optional[datetime.datetime] = None + + @field_validator('end_date') + @classmethod + def end_date_must_be_after_start_date(cls, v, info): + if v and 'start_date' in info.data and info.data['start_date'] and v <= info.data['start_date']: + raise ValueError('end_date must be after start_date') + return v + + @field_serializer("start_date", when_used="json") + def _serialize_start_date(self, dt: Optional[datetime.datetime]): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("end_date", when_used="json") + def _serialize_end_date(self, dt: Optional[datetime.datetime]): + return int(dt.timestamp() * 1000) if dt else 
None + + +class AnalysisConfig(BaseModel): + """Configuration for analysis operations.""" + llm_model_id: Optional[str] = None + batch_size: int = 100 + confidence_threshold: float = 0.5 + include_historical_trends: bool = False + max_conversations: Optional[int] = None + + +# Response Schemas + +class PreferenceTagResponse(BaseModel): + """A user preference tag with detailed context.""" + model_config = ConfigDict(from_attributes=True) + + tag_name: str + confidence_score: float = Field(ge=0.0, le=1.0) + supporting_evidence: List[str] + context_details: str + created_at: datetime.datetime + updated_at: datetime.datetime + conversation_references: List[str] + category: Optional[str] = None + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class DimensionScoreResponse(BaseModel): + """Score for a personality dimension.""" + model_config = ConfigDict(from_attributes=True) + + dimension_name: str + percentage: float = Field(ge=0.0, le=100.0) + evidence: List[str] + reasoning: str + confidence_level: int = Field(ge=0, le=100) + + +class DimensionPortraitResponse(BaseModel): + """Four-dimension personality portrait.""" + model_config = ConfigDict(from_attributes=True) + + user_id: str + creativity: DimensionScoreResponse + aesthetic: DimensionScoreResponse + technology: DimensionScoreResponse + literature: DimensionScoreResponse + analysis_timestamp: datetime.datetime + total_summaries_analyzed: int + historical_trends: Optional[List[Dict[str, Any]]] = None + + @field_serializer("analysis_timestamp", when_used="json") + def _serialize_analysis_timestamp(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class InterestCategoryResponse(BaseModel): + 
"""Interest category with percentage and evidence.""" + model_config = ConfigDict(from_attributes=True) + + category_name: str + percentage: float = Field(ge=0.0, le=100.0) + evidence: List[str] + trending_direction: Optional[str] = None + + +class InterestAreaDistributionResponse(BaseModel): + """Distribution of user interests across four areas.""" + model_config = ConfigDict(from_attributes=True) + + user_id: str + tech: InterestCategoryResponse + lifestyle: InterestCategoryResponse + music: InterestCategoryResponse + art: InterestCategoryResponse + analysis_timestamp: datetime.datetime + total_summaries_analyzed: int + + @property + def total_percentage(self) -> float: + """Calculate total percentage across all interest areas.""" + return self.tech.percentage + self.lifestyle.percentage + self.music.percentage + self.art.percentage + + @field_serializer("analysis_timestamp", when_used="json") + def _serialize_analysis_timestamp(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class BehaviorHabitResponse(BaseModel): + """A behavioral habit identified from conversations.""" + model_config = ConfigDict(from_attributes=True) + + habit_description: str + frequency_pattern: FrequencyPattern + time_context: str + confidence_level: int = Field(ge=0, le=100) + first_observed: datetime.datetime + last_observed: datetime.datetime + is_current: bool = True + specific_examples: List[str] + + @field_serializer("first_observed", when_used="json") + def _serialize_first_observed(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("last_observed", when_used="json") + def _serialize_last_observed(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class UserProfileResponse(BaseModel): + """Comprehensive user profile.""" + model_config = ConfigDict(from_attributes=True) + + user_id: str + preference_tags: List[PreferenceTagResponse] + dimension_portrait: 
DimensionPortraitResponse + interest_area_distribution: InterestAreaDistributionResponse + behavior_habits: List[BehaviorHabitResponse] + profile_version: int + created_at: datetime.datetime + updated_at: datetime.datetime + total_summaries_analyzed: int + analysis_completeness_score: float = Field(ge=0.0, le=1.0) + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +# Internal/Business Logic Schemas + +class MemorySummary(BaseModel): + """Memory summary from existing storage system.""" + model_config = ConfigDict(from_attributes=True) + + summary_id: str + content: str + timestamp: datetime.datetime + participants: List[str] + summary_type: str + + @field_serializer("timestamp", when_used="json") + def _serialize_timestamp(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class UserMemorySummary(BaseModel): + """Memory summary filtered for specific user content.""" + model_config = ConfigDict(from_attributes=True) + + summary_id: str + user_id: str + user_content: str + timestamp: datetime.datetime + confidence_score: float = Field(ge=0.0, le=1.0) + summary_type: str + + @field_serializer("timestamp", when_used="json") + def _serialize_timestamp(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class SummaryAnalysisResult(BaseModel): + """Result of analyzing memory summaries.""" + model_config = ConfigDict(from_attributes=True) + + user_id: str + preferences: List[PreferenceTagResponse] + dimension_evidence: Dict[str, List[str]] + interest_evidence: Dict[str, List[str]] + habit_evidence: List[Dict[str, Any]] + analysis_timestamp: datetime.datetime + summaries_analyzed: List[str] + + 
@field_serializer("analysis_timestamp", when_used="json") + def _serialize_analysis_timestamp(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +# Aliases for backward compatibility with existing code +PreferenceTag = PreferenceTagResponse +DimensionScore = DimensionScoreResponse +DimensionPortrait = DimensionPortraitResponse +InterestCategory = InterestCategoryResponse +InterestAreaDistribution = InterestAreaDistributionResponse +BehaviorHabit = BehaviorHabitResponse +UserProfile = UserProfileResponse diff --git a/api/app/schemas/memory_perceptual_schema.py b/api/app/schemas/memory_perceptual_schema.py new file mode 100644 index 00000000..c2e4517e --- /dev/null +++ b/api/app/schemas/memory_perceptual_schema.py @@ -0,0 +1,136 @@ +import uuid +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, Field + +from app.models.memory_perceptual_model import PerceptualType, FileStorageType + + +class PerceptualFilter(BaseModel): + type: PerceptualType | None = Field( + default=None, + description="Perceptual type used for filtering the query; optional" + ) + + +class PerceptualQuerySchema(BaseModel): + filter: PerceptualFilter = Field( + default_factory=lambda: PerceptualFilter(), + description="Query filter containing perceptual type criteria" + ) + + page: int = Field( + default=1, + ge=1, + description="Page number for pagination, starting from 1" + ) + + page_size: int = Field( + default=10, + ge=1, + le=100, + description="Number of records per page, range 1-100" + ) + + +class PerceptualMemoryItem(BaseModel): + """感知记忆项""" + id: uuid.UUID = Field(..., description="Unique memory ID") + perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video") + file_path: str = Field(..., description="File path in the storage service") + file_ext: str = Field(..., description="File extension") + file_name: str = Field(..., description="File name") + summary: 
Optional[str] = Field(None, description="summary") + storage_type: FileStorageType = Field(..., description="Storage type for file") + created_time: int = Field(..., description="create time") + + class Config: + from_attributes = True + + +class PerceptualTimelineResponse(BaseModel): + """感知记忆时间线响应""" + total: int = Field(..., description="总数量") + page: int = Field(..., description="当前页码") + page_size: int = Field(..., description="每页大小") + total_pages: int = Field(..., description="总页数") + memories: list[PerceptualMemoryItem] = Field(..., description="记忆列表") + + class Config: + from_attributes = True + + +# -------------------------- +# TODO: FileMetaData +# -------------------------- +class Identity(BaseModel): + title: str + filename: str + source: str # upload | crawl | system + author: Optional[str] = None + + +class Semantic(BaseModel): + topic: str + domain: str + difficulty: str # beginner | intermediate | advanced + intent: str # informative | instructional | promotional + sentiment: str # positive | neutral | negative + + +class Content(BaseModel): + summary: str + keywords: list[str] + topic: str + domain: str + + +class Usage(BaseModel): + target_audience: list[str] + use_cases: list[str] + + +class Stats(BaseModel): + duration_sec: Optional[int] = None + char_count: int + word_count: int + + +class Processing(BaseModel): + transcribed: bool + ocr_applied: bool + chunked: bool + vectorized: bool + embedding_model: Optional[str] = None + + +class VideoModal(BaseModel): + scene: list[str] + + +class AudioModal(BaseModel): + speaker_count: int + + +class TextModal(BaseModel): + section_count: int + title: str + first_line: str + + +class Asset(BaseModel): + type: str + modality: str # text | audio | video + format: str # docx | mp3 | mp4 + language: str + encoding: str + + identity: Identity + semantic: Semantic + content: Content + usage: Usage + stats: Stats + processing: Processing + created_at: str + modalities: AudioModal | TextModal | VideoModal 
diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 24747c34..ca9b29de 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -409,10 +409,9 @@ class ForgettingTriggerRequest(BaseModel): """手动触发遗忘周期请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - group_id: Optional[str] = Field(None, description="组ID(可选,用于过滤特定组的节点)") + group_id: str = Field(..., description="组ID(即终端用户ID,必填)") max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)") min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)") - config_id: Optional[int] = Field(None, description="配置ID(可选,用于指定遗忘引擎配置)") # TODO 后续group_id更换成enduser_id,自动与config_id关联 ,要删除此行 class ForgettingConfigResponse(BaseModel): @@ -450,15 +449,36 @@ class ForgettingConfigUpdateRequest(BaseModel): forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="遗忘周期间隔(小时)") +class ForgettingCycleHistoryPoint(BaseModel): + """遗忘周期历史数据点模型(用于趋势图)""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + date: str = Field(..., description="日期(格式: '1/1', '1/2')") + merged_count: int = Field(..., description="每日融合节点数") + average_activation: Optional[float] = Field(None, description="平均激活值") + total_nodes: int = Field(..., description="总节点数") + execution_time: int = Field(..., description="执行时间(Unix时间戳,秒)") + + +class PendingForgettingNode(BaseModel): + """待遗忘节点模型""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + node_id: str = Field(..., description="节点ID") + node_type: str = Field(..., description="节点类型:statement/entity/summary") + content_summary: str = Field(..., description="内容摘要") + activation_value: float = Field(..., description="激活值") + last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)") + + class ForgettingStatsResponse(BaseModel): """遗忘引擎统计信息响应模型""" model_config = 
ConfigDict(populate_by_name=True, extra="forbid") activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标") node_distribution: Dict[str, int] = Field(..., description="节点类型分布") - consistency_check: Optional[Dict[str, Any]] = Field(None, description="数据一致性检查结果") - nodes_merged_total: int = Field(..., description="累计融合节点对数") - recent_cycles: List[Dict[str, Any]] = Field(..., description="最近的遗忘周期记录") - timestamp: str = Field(..., description="统计时间(ISO格式)") + recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)") + pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)") + timestamp: int = Field(..., description="统计时间(时间戳)") class ForgettingReportResponse(BaseModel): diff --git a/api/app/schemas/multi_agent_schema.py b/api/app/schemas/multi_agent_schema.py index c666a2c0..c0d72cdd 100644 --- a/api/app/schemas/multi_agent_schema.py +++ b/api/app/schemas/multi_agent_schema.py @@ -66,9 +66,9 @@ class MultiAgentConfigCreate(BaseModel): master_agent_id: uuid.UUID = Field(..., description="主 Agent ID") master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") orchestration_mode: str = Field( - ..., - pattern="^(sequential|parallel|conditional|loop)$", - description="编排模式:sequential|parallel|conditional|loop" + default="collaboration", + pattern="^(collaboration|supervisor)$", + description="协作模式:collaboration(协作)| supervisor(监督)" ) sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表") routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则") @@ -84,14 +84,15 @@ class MultiAgentConfigUpdate(BaseModel): """更新多 Agent 配置""" master_agent_id: Optional[uuid.UUID] = None master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") - default_model_config_id : uuid.UUID = Field(description="默认模型配置ID") - model_parameters: ModelParameters | None = Field( - default_factory=ModelParameters, + 
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID") + model_parameters: Optional[ModelParameters] = Field( + None, description="模型参数配置(temperature、max_tokens 等)" ) orchestration_mode: Optional[str] = Field( - None, - pattern="^(sequential|parallel|conditional|loop)$" + default="collaboration", + pattern="^(collaboration|supervisor)$", + description="协作模式:collaboration(协作)| supervisor(监督)" ) sub_agents: Optional[List[SubAgentConfig]] = None routing_rules: Optional[List[RoutingRule]] = None diff --git a/api/app/schemas/tool_schema.py b/api/app/schemas/tool_schema.py index baabe186..48afe2c3 100644 --- a/api/app/schemas/tool_schema.py +++ b/api/app/schemas/tool_schema.py @@ -154,7 +154,7 @@ class MCPToolConfigSchema(BaseModel): last_health_check: Optional[datetime] = None health_status: str = "unknown" error_message: Optional[str] = None - available_tools: List[str] = Field(default_factory=list) + available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]") class Config: from_attributes = True diff --git a/api/app/schemas/user_memory_schema.py b/api/app/schemas/user_memory_schema.py new file mode 100644 index 00000000..796ad72f --- /dev/null +++ b/api/app/schemas/user_memory_schema.py @@ -0,0 +1,43 @@ +""" +用户记忆相关的请求和响应模型 +""" +from pydantic import BaseModel, Field +from typing import Optional + + +class EpisodicMemoryOverviewRequest(BaseModel): + """情景记忆总览查询请求""" + + end_user_id: str = Field(..., description="终端用户ID") + time_range: str = Field( + default="all", + description="时间范围筛选,可选值:all, today, this_week, this_month" + ) + episodic_type: str = Field( + default="all", + description="情景类型筛选,可选值:all, conversation, project_work, learning, decision, important_event" + ) + title_keyword: Optional[str] = Field( + default=None, + description="标题关键词,用于模糊搜索(可选)" + ) + + +class EpisodicMemoryDetailsRequest(BaseModel): + """情景记忆详情查询请求""" + + end_user_id: str = 
Field(..., description="终端用户ID") + summary_id: str = Field(..., description="情景记忆摘要ID") + + +class ExplicitMemoryOverviewRequest(BaseModel): + """显性记忆总览查询请求""" + + end_user_id: str = Field(..., description="终端用户ID") + + +class ExplicitMemoryDetailsRequest(BaseModel): + """显性记忆详情查询请求""" + + end_user_id: str = Field(..., description="终端用户ID") + memory_id: str = Field(..., description="记忆ID(情景记忆或语义记忆的ID)") diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index 262c1c04..094aade8 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -2,14 +2,14 @@ Agent 配置格式转换器 用于将 Pydantic 模型转换为数据库存储格式 """ -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Union from app.schemas.app_schema import ( KnowledgeRetrievalConfig, MemoryConfig, VariableDefinition, ToolConfig, AgentConfigCreate, - AgentConfigUpdate, + AgentConfigUpdate, ToolOldConfig, ) @@ -47,10 +47,7 @@ class AgentConfigConverter: # 5. 工具配置 if hasattr(config, 'tools') and config.tools: - result["tools"] = { - name: tool.model_dump() - for name, tool in config.tools.items() - } + result["tools"] = [tool.model_dump() for tool in config.tools] return result @@ -60,7 +57,7 @@ class AgentConfigConverter: knowledge_retrieval: Optional[Dict[str, Any]], memory: Optional[Dict[str, Any]], variables: Optional[list], - tools: Optional[Dict[str, Any]], + tools: Optional[Union[list, Dict[str, Any]]], ) -> Dict[str, Any]: """ 将数据库存储格式转换为 Pydantic 对象 @@ -80,13 +77,18 @@ class AgentConfigConverter: "knowledge_retrieval": None, "memory": MemoryConfig(enabled=True), "variables": [], - "tools": {}, + "tools": [], } # 1. 
解析模型参数配置 if model_parameters: from app.schemas.app_schema import ModelParameters - result["model_parameters"] = ModelParameters(**model_parameters) + if isinstance(model_parameters, ModelParameters): + result["model_parameters"] = model_parameters + elif isinstance(model_parameters, dict): + result["model_parameters"] = ModelParameters(**model_parameters) + else: + result["model_parameters"] = ModelParameters() # 2. 解析知识库检索配置 if knowledge_retrieval: @@ -108,9 +110,12 @@ class AgentConfigConverter: # 5. 解析工具配置 if tools: - result["tools"] = { - name: ToolConfig(**tool_data) - for name, tool_data in tools.items() - } + if isinstance(tools, list): + result["tools"] = [ToolConfig(**tool_config) for tool_config in tools] + else: + result["tools"] = { + name: ToolOldConfig(**tool_data) + for name, tool_data in tools.items() + } return result diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index efcf318d..ec046013 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -9,7 +9,13 @@ from fastapi import Depends from sqlalchemy.orm import Session from app.core.agent.langchain_agent import LangChainAgent +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger +from app.db import get_db, get_db_context +from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig +from app.services.tool_service import ToolService +from app.repositories.tool_repository import ToolRepository from app.db import get_db from app.models import MultiAgentConfig, AgentConfig from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole @@ -18,6 +24,7 @@ from app.services.draft_run_service import create_knowledge_retrieval_tool, crea from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import 
MultiAgentOrchestrator +from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -40,6 +47,7 @@ class AppChatService: memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + workspace_id: Optional[str] = None ) -> Dict[str, Any]: """聊天(非流式)""" @@ -65,6 +73,24 @@ class AppChatService: # 准备工具列表 tools = [] + # 获取工具服务 + tool_service = ToolService(self.db) + + # 从配置中获取启用的工具 + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): + for tool_config in config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, workspace_id)) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval if knowledge_retrieval: @@ -83,20 +109,21 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools = config.tools - # web_search_choice = web_tools.get("web_search", {}) - # web_search_enable = web_search_choice.get("enabled", False) - # if web_search == True: - # if web_search_enable == True: - # search_tool = create_web_search_tool({}) - # tools.append(search_tool) - # - # logger.debug( - # "已添加网络搜索工具", - # extra={ - # "tool_count": len(tools) - # } - # ) + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + 
"已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) # 获取模型参数 model_parameters = config.model_parameters @@ -170,6 +197,7 @@ class AppChatService: memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + workspace_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: """聊天(流式)""" @@ -184,7 +212,7 @@ class AppChatService: model_config_id = config.default_model_config_id api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "") + system_prompt = config.system_prompt if variables: system_prompt_rendered = render_prompt_message( system_prompt, @@ -196,8 +224,25 @@ class AppChatService: # 准备工具列表 tools = [] + # 获取工具服务 + tool_service = ToolService(self.db) + + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): + for tool_config in config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, workspace_id)) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") + knowledge_retrieval = config.knowledge_retrieval if knowledge_retrieval: knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] @@ -208,29 +253,30 @@ class AppChatService: # 添加长期记忆工具 memory_flag = False if memory: - memory_config = config.get("memory", {}) + memory_config = config.memory if memory_config.get("enabled") and user_id: memory_flag = True memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools = config.get("tools") - 
web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) # 获取模型参数 - model_parameters = config.get("model_parameters", {}) + model_parameters = config.model_parameters # 创建 LangChain Agent agent = LangChainAgent( @@ -479,7 +525,9 @@ class AppChatService: self, message: str, conversation_id: uuid.UUID, - config: AgentConfig, + config: WorkflowConfig, + app_id: uuid.UUID, + workspace_id: uuid.UUID, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, @@ -488,281 +536,159 @@ class AppChatService: user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" + workflow_service = WorkflowService(self.db) - start_time = time.time() - config_id = None + input_data = {"message":message, "variables": variables, + "conversation_id": str(conversation_id)} + inconfig = workflow_service.get_workflow_config(app_id) - if variables is None: - variables = {} + # 2. 
创建执行记录 + execution = workflow_service.create_execution( + workflow_config_id=inconfig.id, + app_id=app_id, + trigger_type="manual", + triggered_by=None, + conversation_id=conversation_id, + input_data=input_data + ) - # 获取模型配置ID - model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) - # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "") - if variables: - system_prompt_rendered = render_prompt_message( - system_prompt, - PromptMessageRole.USER, - variables + # 3. 构建工作流配置字典 + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config + } + + # 4. 获取工作空间 ID(从 app 获取) + + # 5. 执行工作流 + from app.core.workflow.executor import execute_workflow + + try: + # 更新状态为运行中 + workflow_service.update_execution_status(execution.execution_id, "running") + + result = await execute_workflow( + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id=str(workspace_id), + user_id=user_id ) - system_prompt = system_prompt_rendered.get_text_content() or system_prompt - # 准备工具列表 - tools = [] - - # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 - memory_flag = False - if memory == True: - memory_config = config.get("memory", {}) - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) - - web_tools = config.get("tools") - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) 
- if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } + # 更新执行结果 + if result.get("status") == "completed": + workflow_service.update_execution_status( + execution.execution_id, + "completed", + output_data=result.get("node_outputs", {}) + ) + else: + workflow_service.update_execution_status( + execution.execution_id, + "failed", + error_message=result.get("error") ) - # 获取模型参数 - model_parameters = config.get("model_parameters", {}) + # 返回增强的响应结构 + return { + "execution_id": execution.execution_id, + "status": result.get("status"), + "output": result.get("output"), # 最终输出(字符串) + "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) + "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID + "error_message": result.get("error"), + "elapsed_time": result.get("elapsed_time"), + "token_usage": result.get("token_usage") + } - # 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_obj.model_name, - api_key=api_key_obj.api_key, - provider=api_key_obj.provider, - api_base=api_key_obj.api_base, - temperature=model_parameters.get("temperature", 0.7), - max_tokens=model_parameters.get("max_tokens", 2000), - system_prompt=system_prompt, - tools=tools, - - ) - - # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) + except Exception as e: + logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + workflow_service.update_execution_status( + execution.execution_id, + "failed", + error_message=str(e) + ) + raise BusinessException( + code=BizCode.INTERNAL_ERROR, + message=f"工作流执行失败: {str(e)}" ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in 
messages - ] - - # 调用 Agent - result = await agent.chat( - message=message, - history=history, - context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag - ) - - # 保存消息 - self.conversation_service.save_conversation_messages( - conversation_id=conversation_id, - user_message=message, - assistant_message=result["content"] - ) - - elapsed_time = time.time() - start_time - - return { - "conversation_id": conversation_id, - "message": result["content"], - "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }), - "elapsed_time": elapsed_time - } async def workflow_chat_stream( self, message: str, conversation_id: uuid.UUID, - config: AgentConfig, + config: WorkflowConfig, + app_id: uuid.UUID, + workspace_id: uuid.UUID, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + ) -> AsyncGenerator[str, None]: """聊天(流式)""" + workflow_service = WorkflowService(self.db) + input_data = {"message": message, "variables": variables, + "conversation_id": str(conversation_id)} + inconfig = workflow_service.get_workflow_config(app_id) + # 2. 创建执行记录 + execution = workflow_service.create_execution( + workflow_config_id=inconfig.id, + app_id=app_id, + trigger_type="manual", + triggered_by=None, + conversation_id=conversation_id, + input_data=input_data + ) + + # 3. 构建工作流配置字典 + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config + } + + # 4. 获取工作空间 ID(从 app 获取) + + # 5. 
流式执行工作流 try: - start_time = time.time() - config_id = None + # 更新状态为运行中 + workflow_service.update_execution_status(execution.execution_id, "running") - if variables is None: - variables = {} - # 获取模型配置ID - model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) - # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "") - if variables: - system_prompt_rendered = render_prompt_message( - system_prompt, - PromptMessageRole.USER, - variables - ) - system_prompt = system_prompt_rendered.get_text_content() or system_prompt - - # 准备工具列表 - tools = [] - - # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 - memory_flag = False - if memory: - memory_config = config.get("memory", {}) - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) - - web_tools = config.get("tools") - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 获取模型参数 - model_parameters = config.get("model_parameters", {}) - - # 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_obj.model_name, - api_key=api_key_obj.api_key, - provider=api_key_obj.provider, - api_base=api_key_obj.api_base, - temperature=model_parameters.get("temperature", 0.7), - max_tokens=model_parameters.get("max_tokens", 2000), - system_prompt=system_prompt, 
- tools=tools, - streaming=True - ) - - # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] - - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - - # 流式调用 Agent - full_content = "" - async for chunk in agent.chat_stream( - message=message, - history=history, - context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag + # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) + async for event in workflow_service._run_workflow_stream( + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id=str(workspace_id), + user_id=user_id ): - full_content += chunk - # 发送消息块事件 - yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" + # 直接转发 executor 的事件(已经是正确的格式) + yield event - elapsed_time = time.time() - start_time - - # 保存消息 - self.conversation_service.add_message( - conversation_id=conversation_id, - role="user", - content=message - ) - - self.conversation_service.add_message( - conversation_id=conversation_id, - role="assistant", - content=full_content, - meta_data={ - "model": api_key_obj.model_name, - "usage": {} - } - ) - - # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} - yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" - - logger.info( - "流式聊天完成", - extra={ - "conversation_id": str(conversation_id), - "elapsed_time": elapsed_time, - "message_length": len(full_content) - } - ) - - except (GeneratorExit, asyncio.CancelledError): - # 生成器被关闭或任务被取消,正常退出 - 
logger.debug("流式聊天被中断") - raise except Exception as e: - logger.error(f"流式聊天失败: {str(e)}", exc_info=True) + logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + workflow_service.update_execution_status( + execution.execution_id, + "failed", + error_message=str(e) + ) # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + yield { + "event": "error", + "data": { + "execution_id": execution.execution_id, + "error": str(e) + } + } + # ==================== 依赖注入函数 ==================== def get_app_chat_service( diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 95bcc07a..6d5204f8 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -21,6 +21,7 @@ from app.core.exceptions import ( BusinessException, ) from app.core.logging_config import get_business_logger +from app.core.workflow.validator import WorkflowValidator from app.db import get_db from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig from app.models.app_model import AppStatus, AppType @@ -31,6 +32,8 @@ from app.schemas.workflow_schema import WorkflowConfigUpdate from app.services.agent_config_converter import AgentConfigConverter from app.models import AppShare, Workspace from app.services.model_service import ModelApiKeyService +from app.services.workflow_service import WorkflowService +from app.utils.app_config_utils import model_parameters_to_dict # 获取业务日志器 logger = get_business_logger() @@ -201,27 +204,28 @@ class AppService: "多智能体配置未激活,无法运行", BizCode.AGENT_CONFIG_MISSING ) - if not multi_agent_config.default_model_config_id: - # # 2. 检查主 Agent 配置 - if not multi_agent_config.master_agent_id: - raise BusinessException( - "未配置主 Agent,无法运行", - BizCode.AGENT_CONFIG_MISSING - ) + if multi_agent_config.orchestration_mode == "supervisor": + if not multi_agent_config.default_model_config_id: + # # 2. 
检查主 Agent 配置 + if not multi_agent_config.master_agent_id: + raise BusinessException( + "未配置主 Agent,无法运行", + BizCode.AGENT_CONFIG_MISSING + ) - master_agent_release = self.db.get(AppRelease, multi_agent_config.master_agent_id) - if not master_agent_release: - raise BusinessException( - f"主 Agent 配置不存在: {multi_agent_config.master_agent_id}", - BizCode.AGENT_CONFIG_MISSING - ) + master_agent_release = self.db.get(AppRelease, multi_agent_config.master_agent_id) + if not master_agent_release: + raise BusinessException( + f"主 Agent 配置不存在: {multi_agent_config.master_agent_id}", + BizCode.AGENT_CONFIG_MISSING + ) - # 检查主 Agent 的模型配置 - multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id + # 检查主 Agent 的模型配置 + multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id - model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id) - if not model_api_key: - raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id)) + model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id) + if not model_api_key: + raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id)) # 3. 
检查子 Agent 配置 @@ -273,12 +277,7 @@ class AppService: ) logger.info( - "多智能体配置检查通过", - extra={ - "app_id": str(app_id), - "master_agent_id": str(multi_agent_config.master_agent_id), - "sub_agent_count": len(multi_agent_config.sub_agents) - } + "多智能体配置检查通过" ) def _create_agent_config( @@ -305,7 +304,7 @@ class AppService: knowledge_retrieval=storage_data.get("knowledge_retrieval"), memory=storage_data.get("memory"), variables=storage_data.get("variables", []), - tools=storage_data.get("tools", {}), + tools=storage_data.get("tools", []), is_active=True, created_at=now, updated_at=now, @@ -687,7 +686,7 @@ class AppService: knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None, memory=source_config.memory.copy() if source_config.memory else None, variables=source_config.variables.copy() if source_config.variables else [], - tools=source_config.tools.copy() if source_config.tools else {}, + tools=source_config.tools.copy() if source_config.tools else [], is_active=True, created_at=now, updated_at=now, @@ -813,6 +812,37 @@ class AppService: ) return items, int(total) + def get_apps_by_ids( + self, + app_ids: List[str], + workspace_id: uuid.UUID + ) -> List[App]: + """根据ID列表获取应用 + + Args: + app_ids: 应用ID列表 + workspace_id: 工作空间ID(用于权限验证) + + Returns: + List[App]: 应用列表 + """ + if not app_ids: + return [] + + # 转换字符串ID为UUID + try: + uuid_ids = [uuid.UUID(app_id) for app_id in app_ids] + except ValueError: + return [] + + # 查询本工作空间的应用 + 分享给本工作空间的应用 + stmt = select(App).where( + App.id.in_(uuid_ids), + App.workspace_id == workspace_id + ) + + return list(self.db.scalars(stmt).all()) + # ==================== Agent 配置管理 ==================== def update_agent_config( @@ -877,7 +907,7 @@ class AppService: # if data.variables is not None: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: - agent_cfg.tools = storage_data.get("tools", {}) + agent_cfg.tools = storage_data.get("tools", []) 
agent_cfg.updated_at = now @@ -964,7 +994,7 @@ class AppService: "max_history": 10 }, variables=[], - tools={}, + tools=[], is_active=True, created_at=now, updated_at=now, @@ -1177,11 +1207,11 @@ class AppService: config = { "system_prompt": agent_cfg.system_prompt, - "model_parameters": agent_cfg.model_parameters, + "model_parameters": model_parameters_to_dict(agent_cfg.model_parameters), "knowledge_retrieval": agent_cfg.knowledge_retrieval, "memory": agent_cfg.memory, "variables": agent_cfg.variables or [], - "tools": agent_cfg.tools or {}, + "tools": agent_cfg.tools or [], } # config = AgentConfigConverter.from_storage_format(agent_cfg) default_model_config_id = agent_cfg.default_model_config_id @@ -1206,8 +1236,10 @@ class AppService: default_model_config_id = multi_agent_cfg.default_model_config_id # 4. 构建配置快照 + + config = { - "model_parameters":multi_agent_cfg.model_parameters, + "model_parameters": model_parameters_to_dict(multi_agent_cfg.model_parameters), "master_agent_id": str(multi_agent_cfg.master_agent_id), "orchestration_mode": multi_agent_cfg.orchestration_mode, "sub_agents": multi_agent_cfg.sub_agents, @@ -1225,6 +1257,26 @@ class AppService: "orchestration_mode": multi_agent_cfg.orchestration_mode } ) + elif app.type == AppType.WORKFLOW: + service = WorkflowService(self.db) + workflow_cfg = service.get_workflow_config(app_id) + if not workflow_cfg: + raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING) + + config = { + "nodes": workflow_cfg.nodes, + "edges": workflow_cfg.edges, + "variables": workflow_cfg.variables, + "execution_config": workflow_cfg.execution_config, + "triggers": workflow_cfg.triggers + } + + is_valid, errors = WorkflowValidator.validate_for_publish(config) + if not is_valid: + raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING) + logger.info( + "应用发布配置准备完成" + ) now = datetime.datetime.now() version = self._get_next_version(app_id) @@ -2047,6 +2099,16 @@ def list_apps( ) +def get_apps_by_ids( + db: 
Session, + app_ids: List[str], + workspace_id: uuid.UUID +) -> List[App]: + """根据ID列表获取应用(向后兼容接口)""" + service = AppService(db) + return service.get_apps_by_ids(app_ids, workspace_id) + + # ==================== 向后兼容的函数接口 ==================== async def draft_run( diff --git a/api/app/services/collaborative_orchestrator.py b/api/app/services/collaborative_orchestrator.py index bfb54f65..f01b7e01 100644 --- a/api/app/services/collaborative_orchestrator.py +++ b/api/app/services/collaborative_orchestrator.py @@ -537,7 +537,7 @@ class CollaborativeOrchestrator: }) # 提取 usage - if hasattr(response, 'usage_metadata'): + if hasattr(response, 'usage_metadata') and response.usage_metadata: result["usage"] = { "prompt_tokens": response.usage_metadata.get("input_tokens", 0), "completion_tokens": response.usage_metadata.get("output_tokens", 0), diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 122d0d87..3695a222 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -1,177 +1,287 @@ """会话服务""" import uuid +from datetime import datetime, timedelta from typing import Annotated from typing import Optional, List, Tuple +import json_repair from fastapi import Depends -from sqlalchemy import select, desc +from jinja2 import Template from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.exceptions import ResourceNotFoundException from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig from app.db import get_db -from app.models import Conversation, Message +from app.models import Conversation, Message, User, ModelType +from app.models.conversation_model import ConversationDetail +from app.models.prompt_optimizer_model import RoleType +from app.repositories.conversation_repository import ConversationRepository, MessageRepository +from 
app.schemas.conversation_schema import ConversationOut +from app.services import workspace_service +from app.services.model_service import ModelConfigService logger = get_business_logger() class ConversationService: - """会话服务""" + """ + Service layer for managing conversations and messages. + Provides methods to create, retrieve, list, and manipulate conversations and messages. + Delegates database operations to repositories. + """ def __init__(self, db: Session): self.db = db + self.conversation_repo = ConversationRepository(db) + self.message_repo = MessageRepository(db) def create_conversation( - self, - app_id: uuid.UUID, - workspace_id: uuid.UUID, - user_id: Optional[str] = None, - title: Optional[str] = None, - is_draft: bool = False, - config_snapshot: Optional[dict] = None + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + user_id: Optional[str] = None, + title: Optional[str] = None, + is_draft: bool = False, + config_snapshot: Optional[dict] = None ) -> Conversation: - """创建会话""" - conversation = Conversation( - app_id=app_id, - workspace_id=workspace_id, - user_id=user_id, - title=title or "新会话", - is_draft=is_draft, - config_snapshot=config_snapshot - ) + """ + Create a new conversation in the system. - self.db.add(conversation) - self.db.commit() - self.db.refresh(conversation) + Args: + app_id (uuid.UUID): The application ID the conversation belongs to. + workspace_id (uuid.UUID): Workspace ID for context. + user_id (Optional[str]): Optional user ID for the conversation owner. + title (Optional[str]): Conversation title. Defaults to 'New Conversation' if not provided. + is_draft (bool): Whether the conversation is a draft. + config_snapshot (Optional[dict]): Optional configuration snapshot. - logger.info( - "创建会话成功", - extra={ - "conversation_id": str(conversation.id), - "app_id": str(app_id), - "workspace_id": str(workspace_id), - "is_draft": is_draft - } - ) + Returns: + Conversation: Newly created Conversation instance. 
+ """ + try: + conversation = self.conversation_repo.create_conversation( + app_id=app_id, + workspace_id=workspace_id, + user_id=user_id, + title=title or "New Conversation", + is_draft=is_draft, + config_snapshot=config_snapshot + ) + self.db.commit() + self.db.refresh(conversation) + + logger.info( + "Create Conversation Success", + extra={ + "conversation_id": str(conversation.id), + "app_id": str(app_id), + "workspace_id": str(workspace_id), + "is_draft": is_draft + } + ) + except Exception as e: + logger.error( + f"Create Conversation Failed - {str(e)}" + ) + self.db.rollback() + raise BusinessException("Error creating conversation", code=BizCode.DB_ERROR) return conversation def get_conversation( - self, - conversation_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None + self, + conversation_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None ) -> Conversation: - """获取会话""" - stmt = select(Conversation).where(Conversation.id == conversation_id) + """ + Retrieve a conversation by its ID. - if workspace_id: - stmt = stmt.where(Conversation.workspace_id == workspace_id) + Args: + conversation_id (uuid.UUID): The conversation UUID. + workspace_id (Optional[uuid.UUID]): Optional workspace UUID to restrict the query. - conversation = self.db.scalars(stmt).first() + Raises: + ResourceNotFoundException: If the conversation does not exist. - if not conversation: - raise ResourceNotFoundException("会话", str(conversation_id)) + Returns: + Conversation: The requested Conversation instance. 
+ """ + conversation = self.conversation_repo.get_conversation_by_conversation_id( + conversation_id=conversation_id, + workspace_id=workspace_id + ) return conversation - def list_conversations( - self, - app_id: uuid.UUID, - workspace_id: uuid.UUID, - user_id: Optional[str] = None, - is_draft: Optional[bool] = None, - page: int = 1, - pagesize: int = 20 - ) -> Tuple[List[Conversation], int]: - """列出会话""" - stmt = select(Conversation).where( - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active == True + def get_user_conversations( + self, + user_id: uuid.UUID + ) -> list[Conversation]: + """ + Retrieve recent conversations for a specific user + + This method delegates persistence logic to the repository layer and + applies service-level defaults (e.g. recent conversation limit). + + Args: + user_id (uuid.UUID): Unique identifier of the user. + + Returns: + list[Conversation]: A list of recent conversation entities. + """ + conversations = self.conversation_repo.get_conversation_by_user_id( + user_id, + limit=10 ) + return conversations - if user_id: - stmt = stmt.where(Conversation.user_id == user_id) + def list_conversations( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + user_id: Optional[str] = None, + is_draft: Optional[bool] = None, + page: int = 1, + pagesize: int = 20 + ) -> Tuple[List[Conversation], int]: + """ + List conversations with optional filters and pagination. - if is_draft is not None: - stmt = stmt.where(Conversation.is_draft == is_draft) + Args: + app_id (uuid.UUID): Application ID filter. + workspace_id (uuid.UUID): Workspace ID filter. + user_id (Optional[str]): Optional user ID filter. + is_draft (Optional[bool]): Optional draft status filter. + page (int): Page number, 1-based. + pagesize (int): Number of items per page. 
- # 总数 - count_stmt = stmt.with_only_columns(Conversation.id) - total = len(self.db.execute(count_stmt).all()) - - # 分页 - stmt = stmt.order_by(desc(Conversation.updated_at)) - stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) - - conversations = list(self.db.scalars(stmt).all()) + Returns: + Tuple[List[Conversation], int]: A list of Conversation instances and the total count. + """ + conversations, total = self.conversation_repo.list_conversations( + app_id=app_id, + workspace_id=workspace_id, + user_id=user_id, + is_draft=is_draft, + page=page, + pagesize=pagesize + ) return conversations, total def add_message( - self, - conversation_id: uuid.UUID, - role: str, - content: str, - meta_data: Optional[dict] = None + self, + conversation_id: uuid.UUID, + role: str, + content: str, + meta_data: Optional[dict] = None ) -> Message: - """添加消息""" - message = Message( - conversation_id=conversation_id, - role=role, - content=content, - meta_data=meta_data - ) + """ + Add a message to a conversation using UnitOfWork. - self.db.add(message) + Args: + conversation_id (uuid.UUID): Conversation UUID. + role (str): Role of the message sender ('user' or 'assistant'). + content (str): Message content. + meta_data (Optional[dict]): Optional metadata. - # 更新会话的消息计数和更新时间 - conversation = self.get_conversation(conversation_id) - conversation.message_count += 1 + Returns: + Message: Newly created Message instance. + """ + try: + conversation = self.conversation_repo.get_conversation_by_conversation_id( + conversation_id + ) - # 如果是第一条用户消息,可以用它作为标题 - if conversation.message_count == 1 and role == "user": - conversation.title = content[:50] + ("..." 
if len(content) > 50 else "") + message = Message( + conversation_id=conversation_id, + role=role, + content=content, + meta_data=meta_data, + ) - self.db.commit() - self.db.refresh(message) + self.message_repo.add_message(message) - return message + conversation.message_count += 1 + + if conversation.message_count == 1 and role == "user": + conversation.title = ( + content[:50] + ("..." if len(content) > 50 else "") + ) + + self.db.commit() + self.db.refresh(message) + + logger.info( + "Message added successfully", + extra={ + "conversation_id": str(conversation_id), + "message_id": str(message.id), + "role": role, + "content_length": len(content), + }, + ) + + return message + except Exception as e: + logger.error( + f"Message added error, db roll back - {str(e)}", + extra={ + "conversation_id": str(conversation_id), + "role": role, + "content_length": len(content), + }, + ) + self.db.rollback() + raise BusinessException( + f"Error adding message, conversation_id={conversation_id}", + code=BizCode.DB_ERROR + ) def get_messages( - self, - conversation_id: uuid.UUID, - limit: Optional[int] = None + self, + conversation_id: uuid.UUID, + limit: Optional[int] = None ) -> List[Message]: - """获取会话消息""" - stmt = select(Message).where( - Message.conversation_id == conversation_id - ).order_by(Message.created_at) + """ + Retrieve messages for a conversation. - if limit: - stmt = stmt.limit(limit) + Args: + conversation_id (uuid.UUID): Conversation UUID. + limit (Optional[int]): Optional maximum number of messages. - messages = list(self.db.scalars(stmt).all()) + Returns: + List[Message]: List of messages ordered by creation time. 
+ """ + messages = self.message_repo.get_message_by_conversation_id( + conversation_id, + limit + ) return messages def get_conversation_history( - self, - conversation_id: uuid.UUID, - max_history: Optional[int] = None + self, + conversation_id: uuid.UUID, + max_history: Optional[int] = None ) -> List[dict]: - """获取会话历史消息 + """ + Retrieve historical conversation messages formatted as dictionaries. Args: - conversation_id: 会话ID - max_history: 最大历史消息数量 + conversation_id (uuid.UUID): Conversation UUID. + max_history (Optional[int]): Maximum number of messages to retrieve. Returns: - List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...] + List[dict]: List of message dictionaries with keys 'role' and 'content'. """ - messages = self.get_messages(conversation_id, limit=max_history) + messages = self.message_repo.get_message_by_conversation_id( + conversation_id, + limit=max_history + ) # 转换为字典格式 history = [ @@ -185,20 +295,25 @@ class ConversationService: return history def save_conversation_messages( - self, - conversation_id: uuid.UUID, - user_message: str, - assistant_message: str + self, + conversation_id: uuid.UUID, + user_message: str, + assistant_message: str ): - """保存会话消息(用户消息和助手回复)""" - # 添加用户消息 + """ + Save a pair of user and assistant messages to the conversation. + + Args: + conversation_id (uuid.UUID): Conversation UUID. + user_message (str): User's message content. + assistant_message (str): Assistant's response content. 
+ """ self.add_message( conversation_id=conversation_id, role="user", content=user_message ) - # 添加助手消息 self.add_message( conversation_id=conversation_id, role="assistant", @@ -206,7 +321,7 @@ class ConversationService: ) logger.debug( - "保存会话消息成功", + "Saved conversation messages successfully", extra={ "conversation_id": str(conversation_id), "user_message_length": len(user_message), @@ -215,35 +330,59 @@ class ConversationService: ) def delete_conversation( - self, - conversation_id: uuid.UUID, - workspace_id: uuid.UUID + self, + conversation_id: uuid.UUID, + workspace_id: uuid.UUID ): - """删除会话(软删除)""" - conversation = self.get_conversation(conversation_id, workspace_id) - conversation.is_active = False + """ + Soft delete a conversation. - self.db.commit() + Args: + conversation_id (uuid.UUID): Conversation UUID. + workspace_id (uuid.UUID): Workspace UUID for validation. + """ + try: + self.conversation_repo.soft_delete_conversation_by_conversation_id( + conversation_id, + workspace_id + ) + self.db.commit() - logger.info( - "删除会话成功", - extra={ - "conversation_id": str(conversation_id), - "workspace_id": str(workspace_id) - } - ) + logger.info( + "Soft deleted conversation successfully", + extra={ + "conversation_id": str(conversation_id), + "workspace_id": str(workspace_id) + } + ) + except Exception as e: + self.db.rollback() + logger.error( + f"Error deleting conversation, conversation_id={conversation_id} - {str(e)}", + ) + raise BusinessException("Error deleting conversation", code=BizCode.DB_ERROR) def create_or_get_conversation( - self, - app_id: uuid.UUID, - workspace_id: uuid.UUID, - is_draft: bool = False, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + is_draft: bool = False, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, ) -> Conversation: - """创建或获取会话""" + """ + Retrieve an existing conversation by ID or create a new one. 
- # 如果提供了 conversation_id,尝试获取现有会话 + Args: + app_id (uuid.UUID): Application ID. + workspace_id (uuid.UUID): Workspace ID. + is_draft (bool): Whether the conversation should be a draft. + conversation_id (Optional[uuid.UUID]): Optional conversation ID to retrieve. + user_id (Optional[str]): Optional user ID. + + Returns: + Conversation: Existing or newly created conversation. + """ if conversation_id: try: conversation = self.get_conversation( @@ -253,11 +392,14 @@ class ConversationService: # 验证会话是否属于该应用 if conversation.app_id != app_id: - raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) + raise BusinessException( + "Conversation does not belong to this app", + BizCode.INVALID_CONVERSATION + ) return conversation except ResourceNotFoundException: logger.warning( - "会话不存在,将创建新会话", + "Conversation not found. A new conversation will be created.", extra={"conversation_id": str(conversation_id)} ) @@ -270,15 +412,195 @@ class ConversationService: ) logger.info( - "为分享链接创建新会话" + "Created a new conversation for shared link usage", + extra={ + "conversation_id": str(conversation.id), + } ) return conversation -# ==================== 依赖注入函数 ==================== + async def get_conversation_detail( + self, + user: User, + conversation_id: uuid.UUID, + workspace_id: uuid.UUID, + language: str = "zh" + ) -> ConversationOut: + """ + Retrieve or generate the summary and theme of a conversation. + + This method first attempts to fetch the conversation detail from the repository. + If no detail exists or the conversation was updated within the last day, it generates a new + summary using the configured LLM model, stores it, and returns it. + + Args: + user (User): The user requesting the conversation summary. + conversation_id (UUID): Unique identifier of the conversation. + workspace_id (UUID): Identifier of the workspace where the conversation belongs. + language (str, optional): Language for the summary generation. Defaults to "zh". 
+ + Returns: + ConversationOut: An object containing the conversation's theme, summary, + takeaways, and information score. + + Raises: + BusinessException: If the workspace model is not configured, the model does + not exist, API keys are missing, or the LLM output is invalid. + + Notes: + - If conversation details exist and are recent, they are returned directly. + - LLM generation uses system and user prompt templates from the filesystem. + - JSON repair is applied to ensure model outputs can be safely parsed. + - Commits the new conversation detail only if it is generated or outdated. + """ + logger.info(f"Fetching conversation detail for conversation_id={conversation_id}, workspace_id={workspace_id}") + + conversation_detail = self.conversation_repo.get_conversation_detail( + conversation_id=conversation_id, + ) + conversation = self.get_conversation( + conversation_id=conversation_id, + ) + if not conversation: + raise BusinessException("Conversation not found", BizCode.INVALID_CONVERSATION) + is_stable = ( + conversation.updated_at + and datetime.now() - conversation.updated_at > timedelta(days=1) + ) + if conversation_detail and is_stable: + logger.info(f"Conversation detail found in repository for conversation_id={conversation_id}") + return ConversationOut( + theme=conversation_detail.theme, + question=conversation_detail.question if conversation_detail.question else [], + summary=conversation_detail.summary, + takeaways=conversation_detail.takeaways, + info_score=conversation_detail.info_score, + ) + logger.info("Conversation detail not found, generating new summary using LLM") + configs = workspace_service.get_workspace_models_configs( + db=self.db, + workspace_id=workspace_id, + user=user + ) + model_id = configs.get('llm') + if not model_id: + logger.error(f"Workspace model configuration not found for workspace_id={workspace_id}") + raise BusinessException("Workspace model configuration not found. 
Please configure a model first.", code=BizCode.MODEL_NOT_FOUND) + config = ModelConfigService.get_model_by_id(db=self.db, model_id=model_id) + + if not config: + logger.error(f"Configured model not found for model_id={model_id}") + raise BusinessException("Configured model does not exist.", BizCode.NOT_FOUND) + + if not config.api_keys or len(config.api_keys) == 0: + logger.error(f"Model API keys missing for model_id={model_id}", ) + raise BusinessException("Model configuration missing API keys.", BizCode.INVALID_PARAMETER) + + api_config = config.api_keys[0] + model_name = api_config.model_name + provider = api_config.provider + api_key = api_config.api_key + api_base = api_config.api_base + model_type = config.type + + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=api_base + ), + type=ModelType(model_type) + ) + + conversation_messages = self.get_conversation_history( + conversation_id=conversation_id, + max_history=30 + ) + + with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f: + system_prompt = f.read() + rendered_system_message = Template(system_prompt).render() + + with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f: + user_prompt = f.read() + rendered_user_message = Template(user_prompt).render( + language=language, + conversation=str(conversation_messages) + ) + + messages = [ + (RoleType.SYSTEM, rendered_system_message), + (RoleType.USER, rendered_user_message), + ] + logger.info(f"Invoking LLM for conversation_id={conversation_id}") + model_resp = await llm.ainvoke(messages) + try: + if isinstance(model_resp.content, str): + result = json_repair.repair_json(model_resp.content, return_objects=True) + elif isinstance(model_resp.content, list): + result = json_repair.repair_json(model_resp.content[0].get("text"), return_objects=True) + elif isinstance(model_resp.content, dict): + result = 
model_resp.content + else: + raise BusinessException("Unexpect model output", code=BizCode.LLM_ERROR) + except Exception as e: + logger.exception(f"Failed to parse LLM response for conversation_id={conversation_id}") + raise BusinessException("Failed to parse LLM response", code=BizCode.LLM_ERROR) from e + + summary = result.get('summary', "") + theme = result.get('theme', "") + question = result.get("question") or [] + takeaways = result.get("takeaways") or [] + info_score = result.get("info_score", 50) + + if not is_stable: + if not conversation_detail: + logger.info(f"Creating conversation detail in DB for conversation_id={conversation_id}") + conversation_detail = ConversationDetail( + conversation_id=conversation.id, + summary=summary, + theme=theme, + question=question, + takeaways=takeaways, + info_score=info_score + ) + self.conversation_repo.add_conversation_detail(conversation_detail) + else: + logger.info(f"Updating conversation detail in DB for conversation_id={conversation_id}") + conversation_detail.summary = summary + conversation_detail.theme = theme + conversation_detail.question = question + conversation_detail.takeaways = takeaways + conversation_detail.info_score = info_score + + self.db.commit() + self.db.refresh(conversation_detail) + + logger.info(f"Returning conversation summary for conversation_id={conversation_id}") + conversation_out = ConversationOut( + theme=theme, + question=question, + summary=summary, + takeaways=takeaways, + info_score=info_score + ) + return conversation_out + + +# ==================== Dependency Injection ==================== def get_conversation_service( db: Annotated[Session, Depends(get_db)] ) -> ConversationService: - """获取工作流服务(依赖注入)""" + """ + Dependency injection function to provide ConversationService instance. + + Args: + db (Session): Database session provided by FastAPI dependency. + + Returns: + ConversationService: Service instance. 
+ """ return ConversationService(db) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index c0d2e3ff..cdbb213e 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,19 +10,22 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.models import AgentConfig, ModelApiKey, ModelConfig +from app.repositories.tool_repository import ToolRepository from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session +from app.services.tool_service import ToolService logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -291,16 +294,30 @@ class DraftRunService: # 4. 
准备工具列表 tools = [] - # 添加网络搜索工具 - if web_search: - if agent_config.tools: - web_search_config = agent_config.tools.get("web_search", {}) - web_search_enable = web_search_config.get("enabled", False) + tool_service = ToolService(self.db) - if web_search_enable: - logger.info("网络搜索已启用") - # 创建网络搜索工具 - search_tool = create_web_search_tool(web_search_config) + # 从配置中获取启用的工具 + if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): + if hasattr(agent_config, 'tools') and agent_config.tools: + for tool_config in agent_config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, str(workspace_id))) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): + web_tools = agent_config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) tools.append(search_tool) logger.debug( @@ -454,7 +471,8 @@ class DraftRunService: storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, web_search: bool = True, # 布尔类型默认值 - memory: bool = True # 布尔类型默认值 + memory: bool = True, # 布尔类型默认值 + sub_agent: bool = False # 是否是作为子Agent运行 ) -> AsyncGenerator[str, None]: """执行试运行(流式返回,使用 LangChain Agent) @@ -502,16 +520,29 @@ class DraftRunService: # 4. 
准备工具列表 tools = [] - # 添加网络搜索工具 - if web_search: - if agent_config.tools: - web_search_config = agent_config.tools.get("web_search", {}) - web_search_enable = web_search_config.get("enabled", False) + tool_service = ToolService(self.db) - if web_search_enable: - logger.info("网络搜索已启用") - # 创建网络搜索工具 - search_tool = create_web_search_tool(web_search_config) + # 从配置中获取启用的工具 + if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): + for tool_config in agent_config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, str(workspace_id))) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): + web_tools = agent_config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) tools.append(search_tool) logger.debug( @@ -521,6 +552,7 @@ class DraftRunService: } ) + # 添加知识库检索工具 if agent_config.knowledge_retrieval: kb_config = agent_config.knowledge_retrieval @@ -619,7 +651,7 @@ class DraftRunService: elapsed_time = time.time() - start_time # 10. 
保存会话消息 - if agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py new file mode 100644 index 00000000..be32b864 --- /dev/null +++ b/api/app/services/handoffs_service.py @@ -0,0 +1,609 @@ +"""Handoffs 服务 - 基于 LangGraph 的多 Agent 协作""" +import json +import uuid +from typing import List, Dict, Any, Optional, AsyncGenerator +from typing_extensions import TypedDict + +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage +from langgraph.graph import StateGraph, START, END +from langgraph.types import Command +from langgraph.checkpoint.memory import MemorySaver +from langchain_core.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig +from app.models.models_model import ModelType +from app.services.model_service import ModelApiKeyService + +logger = get_business_logger() + + +# ==================== 状态定义 ==================== + +class HandoffState(TypedDict): + """Handoff 状态""" + messages: List[BaseMessage] + active_agent: Optional[str] + + +# ==================== 工具输入模型 ==================== + +class TransferInput(BaseModel): + """转移工具的输入参数""" + reason: str = Field(description="转移原因") + + +# ==================== 工具创建 ==================== + +def create_transfer_tool(target_agent: str, description: str): + """动态创建转移工具 + + Args: + target_agent: 目标 Agent 名称 + description: 工具描述 + + Returns: + 转移工具函数 + """ + tool_name = f"transfer_to_{target_agent}" + + @tool(tool_name, args_schema=TransferInput) + def transfer_tool(reason: str) -> Command: + """动态生成的转移工具""" + return Command( + goto=target_agent, + update={"active_agent": target_agent}, + 
def create_agent_node(agent_name: str, system_prompt: str, tools: List,
                      model_config: RedBearModelConfig):
    """Create a non-streaming graph node for a single agent.

    Args:
        agent_name: Name of the agent; used in log messages only.
        system_prompt: System prompt prepended to the conversation on every call.
        tools: Transfer tools this agent may call (may be empty).
        model_config: Model configuration used to build the chat LLM.

    Returns:
        An async callable taking the graph state. It returns the ``Command``
        produced by a matching transfer tool, or ``{"messages": [response]}``
        when the model answered directly (or no matching tool produced a
        ``Command``).
    """
    llm = RedBearLLM(model_config, type=ModelType.CHAT)

    # Bind the transfer tools so the model is able to request a handoff.
    if tools:
        llm = llm.bind_tools(tools)

    async def agent_node(state: HandoffState) -> Dict[str, Any]:
        """Run the agent once against the current state."""
        logger.debug(f"Agent {agent_name} 执行, active_agent: {state.get('active_agent')}")

        messages = state.get("messages", [])
        full_messages = [{"role": "system", "content": system_prompt}] + messages

        response = await llm.ainvoke(full_messages)

        # Only the first tool call is honoured: a handoff ends this agent's turn.
        if hasattr(response, 'tool_calls') and response.tool_calls:
            tool_call = response.tool_calls[0]
            tool_name = tool_call["name"] if isinstance(tool_call, dict) else tool_call.name
            tool_args = tool_call["args"] if isinstance(tool_call, dict) else tool_call.args

            # Some providers deliver the arguments as a raw JSON string.
            if isinstance(tool_args, str):
                try:
                    tool_args = json.loads(tool_args)
                except (json.JSONDecodeError, ValueError):
                    tool_args = {}

            # Arguments may be missing (None) or decode to a non-object JSON
            # value (list, number, null); fall back to an empty mapping so the
            # default reason below can still be applied.
            if not isinstance(tool_args, dict):
                tool_args = {}

            # The transfer tool schema requires a reason; supply a default.
            if not tool_args.get("reason"):
                tool_args["reason"] = "用户请求转移"

            for t in tools:
                if t.name == tool_name:
                    logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
                    result = t.invoke(tool_args)
                    if isinstance(result, Command):
                        return result

        return {"messages": [response]}

    return agent_node
def convert_multi_agent_config_to_handoffs(
    multi_agent_config: Dict,
    db: Session
) -> Dict[str, Dict]:
    """Convert a stored multi-agent configuration into handoffs agent configs.

    For every entry in ``multi_agent_config["sub_agents"]`` this looks up the
    agent's release (first treating ``agent_id`` as a release id, then as an
    app id whose ``current_release_id`` is used), reads the system prompt and
    the default model configuration from it, and builds one config dict.

    Args:
        multi_agent_config: Multi-agent configuration loaded from the database.
        db: Database session used for the release / app / API-key lookups.

    Returns:
        Mapping of safe agent name -> config dict with the keys ``agent_id``,
        ``name``, ``description``, ``system_prompt``, ``capabilities``,
        ``model_config`` (``None`` when it could not be resolved) and
        ``can_transfer_to`` (every other agent's safe name).
    """
    from app.models import AppRelease, App

    # ``or []`` also covers a key that is present but explicitly None.
    sub_agents = multi_agent_config.get("sub_agents") or []
    agent_configs = {}
    agent_names = []

    # Build one config per sub agent.
    for sub_agent in sub_agents:
        agent_id = sub_agent.get("agent_id")  # may be a release_id or an app_id
        # str() before slicing: agent_id may already be a uuid.UUID, which does
        # not support slicing. ``or`` also covers a name stored as None/"".
        agent_name = sub_agent.get("name") or f"agent_{str(agent_id)[:8] if agent_id else 'unknown'}"

        # Normalise the name so it can be used as a node / tool name, and keep
        # it unique: two agents mapping to the same safe name would otherwise
        # overwrite each other and appear twice in ``can_transfer_to``.
        base_name = agent_name.replace(" ", "_").replace("-", "_").lower()
        safe_name = base_name
        suffix = 2
        while safe_name in agent_configs:
            safe_name = f"{base_name}_{suffix}"
            suffix += 1
        agent_names.append(safe_name)

        # Defaults used when the release cannot be resolved.
        system_prompt = f"你是 {agent_name}。"
        capabilities = sub_agent.get("capabilities", [])
        model_config = None
        release = None

        if agent_id:
            try:
                agent_id_uuid = uuid.UUID(agent_id) if isinstance(agent_id, str) else agent_id

                # First try the id as a release id.
                release = db.get(AppRelease, agent_id_uuid)

                # Otherwise treat it as an app id and use its current release.
                if not release:
                    app = db.get(App, agent_id_uuid)
                    if app and app.current_release_id:
                        release = db.get(AppRelease, app.current_release_id)

                if release:
                    # The release config may override the default system prompt.
                    if release.config:
                        config_data = release.config
                        release_system_prompt = config_data.get("system_prompt")
                        if release_system_prompt:
                            system_prompt = release_system_prompt

                    # Resolve this agent's own model configuration.
                    if release.default_model_config_id:
                        model_api_key = ModelApiKeyService.get_a_api_key(db, release.default_model_config_id)
                        if model_api_key:
                            model_config = RedBearModelConfig(
                                model_name=model_api_key.model_name,
                                provider=model_api_key.provider,
                                api_key=model_api_key.api_key,
                                base_url=model_api_key.api_base,
                                extra_params={
                                    "temperature": 0.7,
                                    "max_tokens": 2000,
                                    "streaming": True
                                }
                            )
                            logger.debug(f"Agent {agent_name} 使用模型: {model_api_key.model_name}")
                        else:
                            logger.warning(f"Agent {agent_name} 模型配置无效: {release.default_model_config_id}")
                    else:
                        logger.warning(f"Agent {agent_name} 没有配置 default_model_config_id")
                else:
                    logger.warning(f"Agent {agent_name} 找不到发布版本: agent_id={agent_id}")
            except Exception as e:
                # Best effort: a broken agent keeps its defaults (model_config
                # stays None) instead of aborting the whole conversion.
                logger.warning(f"获取 Agent {agent_name} 配置失败: {str(e)}")

        # Append the capabilities, if any, to the system prompt.
        if capabilities:
            if not system_prompt.endswith("。"):
                system_prompt += "。"
            system_prompt += f" 你的专长是: {', '.join(capabilities)}。"

        agent_configs[safe_name] = {
            "agent_id": agent_id,
            "name": agent_name,
            "description": f"转移到 {agent_name}。{sub_agent.get('role') or ''}",
            "system_prompt": system_prompt,
            "capabilities": capabilities,
            "model_config": model_config,  # each agent carries its own model config
            "can_transfer_to": []  # filled in below
        }

    # Every agent may transfer to every other agent.
    for safe_name in agent_names:
        agent_configs[safe_name]["can_transfer_to"] = [
            name for name in agent_names if name != safe_name
        ]
        # Tell the model which transfer tools exist and what they are for.
        other_agents = agent_configs[safe_name]["can_transfer_to"]
        if other_agents:
            transfer_instructions = "\n如果用户的问题不在你的专长范围内,可以使用以下工具转移到其他 Agent:"
            for other_name in other_agents:
                other_config = agent_configs[other_name]
                transfer_instructions += f"\n- transfer_to_{other_name}: {other_config['description']}"
            agent_configs[safe_name]["system_prompt"] += transfer_instructions

    return agent_configs
async def chat(
    self,
    message: str,
    conversation_id: str = None
) -> Dict[str, Any]:
    """Run one non-streaming chat turn through the handoffs graph.

    Args:
        message: The user's message.
        conversation_id: Identifier used as the graph ``thread_id``; a random
            ``conv-xxxxxxxx`` id is generated when omitted.

    Returns:
        Dict with ``conversation_id``, ``active_agent``, ``response`` (content
        of the most recent AI message, or ``""`` if there is none) and
        ``message_count``.
    """
    conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
    config = {"configurable": {"thread_id": str(conversation_id)}}

    logger.info(f"Handoffs chat: conversation_id={conversation_id}, message={message[:50]}...")

    result = await self.graph.ainvoke({
        "messages": [HumanMessage(content=message)]
    }, config=config)

    messages = result.get("messages") or []

    # The reply to this turn is the LAST AI message in the state. Scanning
    # from the front would return the oldest reply as soon as the state holds
    # more than one AI message.
    response_content = next(
        (msg.content for msg in reversed(messages) if isinstance(msg, AIMessage)),
        ""
    )

    return {
        "conversation_id": str(conversation_id),
        "active_agent": result.get("active_agent"),
        "response": response_content,
        "message_count": len(messages)
    }
def reset_handoffs_service_cache(app_id: uuid.UUID = None):
    """Drop cached HandoffsService instances.

    Args:
        app_id: Application id whose cached services should be removed. When
            falsy (the default ``None``), the whole cache is emptied.
    """
    if app_id:
        # Cache keys have the form "<app_id>_<streaming>". Matching on the
        # id plus the separator avoids removing entries of another app whose
        # key merely starts with the same characters.
        prefix = f"{app_id}_"
        for key in [k for k in _service_cache if k.startswith(prefix)]:
            del _service_cache[key]
    else:
        # Clear in place instead of rebinding the module global, so any other
        # reference to the same dict observes the reset as well.
        _service_cache.clear()

    logger.info(f"Handoffs 服务缓存已重置: app_id={app_id}")
def _get_user_memory_config(self) -> MemoryConfig:
    """Load the memory configuration connected to ``self.end_user_id``.

    Returns:
        MemoryConfig: The configuration returned by
        ``MemoryConfigService.load_memory_config`` for the user's
        ``memory_config_id``.

    Raises:
        ValueError: If the user has no ``memory_config_id`` or any step of the
            lookup fails. The original error is attached as ``__cause__``.
    """
    try:
        # Imported here rather than at module level (presumably to avoid an
        # import cycle between the services — TODO confirm).
        from app.services.memory_agent_service import get_end_user_connected_config
        from app.services.memory_config_service import MemoryConfigService

        # Resolve which memory configuration the end user is connected to.
        connected_config = get_end_user_connected_config(self.end_user_id, self.db)
        config_id = connected_config.get("memory_config_id")

        if config_id is None:
            raise ValueError(f"No memory configuration found for end_user: {self.end_user_id}")

        # Load the full configuration object for that id.
        config_service = MemoryConfigService(self.db)
        memory_config = config_service.load_memory_config(config_id)

        logger.info(f"Loaded memory config {config_id} for end_user: {self.end_user_id}")
        return memory_config

    except Exception as e:
        logger.error(f"Failed to get memory config for end_user {self.end_user_id}: {e}")
        # Chain explicitly so the root cause stays attached to the ValueError
        # that callers see.
        raise ValueError(
            f"Unable to get memory configuration for end_user {self.end_user_id}: {e}"
        ) from e
async def get_preference_tags(
    self,
    user_id: str,
    confidence_threshold: float = 0.5,
    tag_category: Optional[str] = None,
    date_range: Optional[DateRange] = None
) -> List[PreferenceTag]:
    """Analyze a user's summaries and return the matching preference tags.

    Args:
        user_id: Target user ID.
        confidence_threshold: Tags scoring below this value are dropped.
        tag_category: When given, only tags of this category are kept.
        date_range: When given, limits both the summaries that are analyzed
            and the ``created_at`` of the returned tags.

    Returns:
        The surviving tags ordered by ``(confidence_score, updated_at)``,
        highest first; an empty list when the user has no summaries.
    """
    logger.info(f"Getting preference tags for user {user_id}")

    def _passes_filters(tag) -> bool:
        """Apply the confidence, category and date filters, in that order."""
        if tag.confidence_score < confidence_threshold:
            return False
        if tag_category and tag.category != tag_category:
            return False
        if date_range:
            if date_range.start_date and tag.created_at < date_range.start_date:
                return False
            if date_range.end_date and tag.created_at > date_range.end_date:
                return False
        return True

    try:
        # An open-ended date range is widened to [datetime.min, now].
        summary_window = (
            TimeRange(
                start_date=date_range.start_date or datetime.min,
                end_date=date_range.end_date or datetime.now()
            )
            if date_range
            else None
        )

        user_summaries = await self.extract_user_summaries(
            user_id=user_id,
            time_range=summary_window
        )
        if not user_summaries:
            logger.warning(f"No summaries found for user {user_id}")
            return []

        # Run the preference analysis over the collected summaries.
        candidate_tags = await self.preference_analyzer.analyze_preferences(
            user_id=user_id,
            user_summaries=user_summaries
        )

        # Filter, then order by confidence and recency (descending).
        filtered_tags = sorted(
            (tag for tag in candidate_tags if _passes_filters(tag)),
            key=lambda tag: (tag.confidence_score, tag.updated_at),
            reverse=True
        )

        logger.info(f"Retrieved {len(filtered_tags)} preference tags for user {user_id}")
        return filtered_tags

    except Exception as e:
        logger.error(f"Failed to get preference tags for user {user_id}: {e}")
        raise
async def get_behavior_habits(
    self,
    user_id: str,
    confidence_level: Optional[int] = None,
    frequency_pattern: Optional[str] = None,
    time_period: Optional[str] = None
) -> List[BehaviorHabit]:
    """Detect a user's behavioral habits and return those passing the filters.

    Args:
        user_id: Target user ID.
        confidence_level: When given, habits below this level are dropped.
        frequency_pattern: When given, only habits with this
            ``FrequencyPattern`` value (case-insensitive) are kept. An
            unrecognized value is logged and yields an empty result.
        time_period: ``"current"`` keeps habits with ``is_current`` set,
            ``"past"`` keeps the others (case-insensitive); any other value
            does not filter.

    Returns:
        The surviving habits ordered by ``(confidence_level, last_observed)``,
        highest first; an empty list when the user has no summaries.
    """
    logger.info(f"Getting behavior habits for user {user_id}")

    try:
        # Collect the summaries the detector works on.
        user_summaries = await self.extract_user_summaries(user_id=user_id)

        if not user_summaries:
            logger.warning(f"No summaries found for user {user_id}")
            return []

        behavior_habits = await self.habit_detector.detect_habits(
            user_id=user_id,
            user_summaries=user_summaries
        )

        # Resolve the filter values once instead of once per habit.
        target_frequency = None
        if frequency_pattern:
            try:
                target_frequency = FrequencyPattern(frequency_pattern.lower())
            except ValueError:
                # No habit can match an unknown pattern: warn once and let the
                # result be empty (previously one warning per habit).
                logger.warning(f"Invalid frequency pattern: {frequency_pattern}")
                behavior_habits = []
        period = time_period.lower() if time_period else None

        filtered_habits = []
        for habit in behavior_habits:
            if confidence_level is not None and habit.confidence_level < confidence_level:
                continue
            if target_frequency is not None and habit.frequency_pattern != target_frequency:
                continue
            if period == "current" and not habit.is_current:
                continue
            if period == "past" and habit.is_current:
                continue
            filtered_habits.append(habit)

        # Order by confidence level and recency (descending).
        filtered_habits.sort(
            key=lambda x: (x.confidence_level, x.last_observed),
            reverse=True
        )

        logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user {user_id}")
        return filtered_habits

    except Exception as e:
        logger.error(f"Failed to get behavior habits for user {user_id}: {e}")
        raise
@@ -4,6 +4,7 @@ Memory Agent Service Handles business logic for memory agent operations including read/write services, health checks, and message type classification. """ +import datetime import json import os import re @@ -24,6 +25,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType +from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig from app.services.memory_config_service import MemoryConfigService @@ -393,7 +395,7 @@ class MemoryAgentService: import time start_time = time.time() - + ori_message=message # Resolve config_id if None using end_user's connected config if config_id is None: try: @@ -406,15 +408,15 @@ class MemoryAgentService: raise # Re-raise our specific error logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") - + logger.info(f"Read operation for group {group_id} with config_id {config_id}") - + # 导入审计日志记录器 try: from app.core.memory.utils.log.audit_logger import audit_logger except ImportError: audit_logger = None - + # Get group lock to prevent concurrent processing group_lock = self.get_group_lock(group_id) @@ -430,7 +432,7 @@ class MemoryAgentService: except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) - + # Log failed operation if audit_logger: duration = time.time() - start_time @@ -442,9 +444,9 @@ class MemoryAgentService: duration=duration, error=error_msg ) - + raise ValueError(error_msg) - + # Step 2: Prepare history history.append({"role": "user", "content": message}) logger.debug(f"Group 
ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") @@ -452,7 +454,7 @@ class MemoryAgentService: # Step 3: Initialize MCP client and execute read workflow mcp_config = get_mcp_server_config() client = MultiServerMCPClient(mcp_config) - + async with client.session('data_flow') as session: logger.debug("Connected to MCP Server: data_flow") tools = await load_mcp_tools(session) @@ -475,7 +477,7 @@ class MemoryAgentService: # Capture any errors from the state if event.get('errors'): workflow_errors.extend(event.get('errors', [])) - + for msg in messages: msg_content = msg.content msg_role = msg.__class__.__name__.lower().replace("message", "") @@ -483,7 +485,7 @@ class MemoryAgentService: "role": msg_role, "content": msg_content }) - + # Extract intermediate outputs if hasattr(msg, 'content'): try: @@ -496,7 +498,7 @@ class MemoryAgentService: break else: continue # No text block found - + # Try to parse content as JSON if isinstance(content_to_parse, str): try: @@ -506,16 +508,16 @@ class MemoryAgentService: if '_intermediate' in parsed: intermediate_data = parsed['_intermediate'] output_key = self._create_intermediate_key(intermediate_data) - + if output_key not in seen_intermediates: seen_intermediates.add(output_key) intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - + # Check for multiple intermediate outputs (from Retrieve) if '_intermediates' in parsed: for intermediate_data in parsed['_intermediates']: output_key = self._create_intermediate_key(intermediate_data) - + if output_key not in seen_intermediates: seen_intermediates.add(output_key) intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) @@ -523,7 +525,7 @@ class MemoryAgentService: pass except Exception as e: logger.debug(f"Failed to extract intermediate output: {e}") - + workflow_duration = time.time() - start logger.info(f"Read graph workflow completed in {workflow_duration}s") @@ -532,7 +534,7 @@ class 
MemoryAgentService: for messages in outputs: if messages['role'] == 'tool': message = messages['content'] - + # Handle MCP content format: [{'type': 'text', 'text': '...'}] if isinstance(message, list): # Extract text from MCP content blocks @@ -542,7 +544,7 @@ class MemoryAgentService: break else: continue # No text block found - + try: parsed = json.loads(message) if isinstance(message, str) else message if isinstance(parsed, dict): @@ -552,15 +554,15 @@ class MemoryAgentService: final_answer = summary_result except (json.JSONDecodeError, ValueError): pass - + # 记录成功的操作 total_duration = time.time() - start_time - + # Check for workflow errors if workflow_errors: error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) logger.warning(f"Read workflow completed with errors: {error_details}") - + if audit_logger: audit_logger.log_operation( operation="READ", @@ -577,11 +579,11 @@ class MemoryAgentService: "errors": workflow_errors } ) - + # Raise error if no answer was produced if not final_answer: raise ValueError(f"Read workflow failed: {error_details}") - + if audit_logger and not workflow_errors: audit_logger.log_operation( operation="READ", @@ -596,7 +598,31 @@ class MemoryAgentService: "has_answer": bool(final_answer) } ) - + retrieved_content=[] + repo = ShortTermMemoryRepository(db) + if str(search_switch)!="2": + for intermediate in intermediate_outputs: + intermediate_type=intermediate['type'] + if intermediate_type=="search_result": + query=intermediate['query'] + raw_results=intermediate['raw_results'] + reranked_results=raw_results.get('reranked_results',[]) + statements=[statement['statement'] for statement in reranked_results.get('statements', [])] + statements=list(set(statements)) + retrieved_content.append({query:statements}) + if '信息不足,无法回答' in str(final_answer) or retrieved_content!=[]: + # 使用 upsert 方法 + repo.upsert( + end_user_id=group_id, # 确保这个变量在作用域内 + messages=ori_message, + aimessages=final_answer, + 
async def get_timeline_memories_server(self):
    """
    Fetch and post-process the timeline memories for this node.

    The Cypher query is chosen from ``self.table``: ``'Statement'`` and
    ``'ExtractedEntity'`` have dedicated queries, every other value falls
    back to the MemorySummary query. The query is executed with
    ``id=self.id`` and the rows are handed to
    ``self._process_timeline_results``.

    Returns:
        On success, whatever ``_process_timeline_results`` returns (a dict
        that exposes a ``timelines_memory`` list, which is counted for the
        log line). On any exception the error is logged with its traceback
        and ``str(exception)`` is returned instead of raising, so callers
        must be prepared for a plain string.
    """
    try:
        logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}")

        # Pick the query first, then run it through a single call site.
        if self.table == 'Statement':
            cypher = Memory_Timeline_Statement
        elif self.table == 'ExtractedEntity':
            cypher = Memory_Timeline_ExtractedEntity
        else:
            cypher = Memory_Timeline_MemorySummary
        rows = await self.connector.execute_query(cypher, id=self.id)

        # Debug aid: log the type and size of the raw query result.
        row_count = len(rows) if isinstance(rows, list) else 'N/A'
        logger.info(f"时间线查询结果类型: {type(rows)}, 长度: {row_count}")

        processed = self._process_timeline_results(rows)

        logger.info(f"成功获取时间线记忆数据: 总计 {len(processed.get('timelines_memory', []))} 条")
        return processed

    except Exception as e:
        logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True)
        return str(e)
statement_list + extracted_entity_list + all_timeline_data = self._merge_same_text_items(all_timeline_data) + + result = { + "MemorySummary": memory_summary_list, + "Statement": statement_list, + "ExtractedEntity": extracted_entity_list, + "timelines_memory": all_timeline_data + } + + logger.info(f"时间线数据处理完成: MemorySummary={len(memory_summary_list)}, Statement={len(statement_list)}, ExtractedEntity={len(extracted_entity_list)}") + + return result + + def _deduplicate_dict_list(self, dict_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 对字典列表进行去重 + + Args: + dict_list: 字典列表 + + Returns: + 去重后的字典列表 + """ + seen = set() + result = [] + + for item in dict_list: + # 使用text作为去重的键 + text = item.get('text', '') + if text and text not in seen: + seen.add(text) + result.append(item) + + return result + + def _merge_same_text_items(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 合并具有相同text的项目,合并type字段,保留一个时间 + + Args: + items: 项目列表 + + Returns: + 合并后的项目列表 + """ + text_groups = {} + + # 按text分组 + for item in items: + text = item.get('text', '') + if not text: + continue + + if text not in text_groups: + text_groups[text] = { + 'text': text, + 'types': set(), + 'created_at': item.get('created_at'), + 'latest_time': item.get('created_at') + } + + # 添加type到集合中 + item_type = item.get('type') + if item_type: + text_groups[text]['types'].add(item_type) + + # 保留最新的时间(如果有的话) + current_time = item.get('created_at') + if current_time and (not text_groups[text]['latest_time'] or + self._is_later_time(current_time, text_groups[text]['latest_time'])): + text_groups[text]['latest_time'] = current_time + + # 转换为最终格式 + result = [] + for text, group_data in text_groups.items(): + merged_item = { + 'text': text, + 'type': ', '.join(sorted(group_data['types'])), # 合并多个type + 'created_at': group_data['latest_time'] + } + result.append(merged_item) + + # 按时间排序(最新的在前) + result.sort(key=lambda x: x.get('created_at', ''), reverse=True) + + return result + + def 
_is_later_time(self, time1: str, time2: str) -> bool: + """ + 比较两个时间字符串,判断time1是否晚于time2 + + Args: + time1: 时间字符串1 + time2: 时间字符串2 + + Returns: + time1是否晚于time2 + """ + try: + if not time1 or not time2: + return bool(time1) # 如果time2为空,time1存在就算更晚 + + # 简单的字符串比较(适用于ISO格式的时间) + return time1 > time2 + except Exception: + return False + + def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]: + """ + 处理字段值,支持字符串、列表等类型 + + Args: + value: 字段值 + field_name: 字段名称(用于日志) + + Returns: + 处理后的字典列表 + """ + processed_values = [] + + try: + if isinstance(value, list): + # 如果是列表,处理每个元素 + for item in value: + if self._is_valid_item(item): + processed_item = self._process_single_item(item) + if processed_item: + processed_values.append(processed_item) + elif isinstance(value, dict): + # 如果是字典,直接处理 + if self._is_valid_item(value): + processed_item = self._process_single_item(value) + if processed_item: + processed_values.append(processed_item) + elif isinstance(value, str): + # 如果是字符串,转换为字典格式 + if value.strip() != '' and "MemorySummaryChunk" not in value: + processed_values.append({ + 'text': value, + 'type': field_name, + 'created_at': None + }) + elif value is not None: + # 其他类型转换为字符串 + str_value = str(value) + if str_value.strip() != '' and "MemorySummaryChunk" not in str_value: + processed_values.append({ + 'text': str_value, + 'type': field_name, + 'created_at': None + }) + except Exception as e: + logger.warning(f"处理字段 {field_name} 的值时出错: {e}, 值类型: {type(value)}, 值: {value}") + + return processed_values + + def _is_valid_item(self, item: Any) -> bool: + """ + 检查项目是否有效 + + Args: + item: 要检查的项目 + + Returns: + 是否有效 + """ + if item is None: + return False + + if isinstance(item, dict): + text = item.get('text') + return (text is not None and + str(text).strip() != '' and + "MemorySummaryChunk" not in str(text)) + + return (str(item).strip() != '' and + "MemorySummaryChunk" not in str(item)) + + def _process_single_item(self, item: Dict[str, Any]) -> 
Optional[Dict[str, Any]]: + """ + 处理单个项目 + + Args: + item: 要处理的项目字典 + + Returns: + 处理后的项目字典 + """ + try: + text = item.get('text') + created_at = item.get('created_at') + item_type = item.get('type', '未知类型') + + # 转换Neo4j时间格式 + formatted_time = self._convert_neo4j_datetime(created_at) + + return { + 'text': text, + 'type': item_type, + 'created_at': formatted_time + } + except Exception as e: + logger.warning(f"处理单个项目时出错: {e}, 项目: {item}") + return None + + def _convert_neo4j_datetime(self, dt: Any) -> str: + """ + 转换Neo4j时间格式为标准时间字符串 + + Args: + dt: Neo4j时间对象或其他时间格式 + + Returns: + 格式化的时间字符串 + """ + if dt is None: + return None + + try: + # 处理Neo4j DateTime对象 + if isinstance(dt, Neo4jDateTime): + return dt.iso_format().replace('T', ' ').split('.')[0] + + # 处理其他neo4j时间类型 + if hasattr(dt, 'iso_format'): + return dt.iso_format().replace('T', ' ').split('.')[0] + + # 处理字符串格式的时间 + if isinstance(dt, str): + # 尝试解析ISO格式 + try: + parsed_dt = datetime.fromisoformat(dt.replace('Z', '+00:00')) + return parsed_dt.strftime("%Y-%m-%d %H:%M:%S") + except ValueError: + return dt + + # 其他情况直接转换为字符串 + return str(dt) + + except Exception as e: + logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}") + return str(dt) if dt is not None else None + + + + + async def close(self): + """关闭数据库连接""" + await self.connector.close() + + + +class MemoryEmotion: + def __init__(self, id: str, table: str): + self.id = id + self.table = table + self.connector = Neo4jConnector() + + def _convert_neo4j_types(self, obj: Any) -> Any: + """ + 递归转换Neo4j特殊类型为可序列化的Python类型 + """ + if isinstance(obj, Neo4jDateTime): + # 转换为用户友好的日期格式 + return self._format_datetime(obj.iso_format()) + elif hasattr(obj, '__class__') and 'neo4j' in str(obj.__class__): + if hasattr(obj, 'iso_format'): + return self._format_datetime(obj.iso_format()) + elif hasattr(obj, '__str__'): + return str(obj) + else: + return repr(obj) + elif isinstance(obj, dict): + return {k: self._convert_neo4j_types(v) for k, v in obj.items()} + elif 
isinstance(obj, list): + return [self._convert_neo4j_types(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(self._convert_neo4j_types(item) for item in obj) + else: + return obj + + def _format_datetime(self, iso_string: str) -> str: + """ + 将ISO格式的日期时间字符串转换为用户友好的格式 + + Args: + iso_string: ISO格式的日期时间字符串,如 "2026-01-07T13:40:33.679530" + + Returns: + 格式化后的日期时间字符串,如 "2026-01-07 13:40:33" + """ + try: + # 解析ISO格式的日期时间 + dt = datetime.fromisoformat(iso_string.replace('Z', '+00:00')) + # 返回用户友好的格式:YYYY-MM-DD HH:MM:SS + return dt.strftime("%Y.%m") + except (ValueError, AttributeError): + # 如果解析失败,返回原始字符串 + return iso_string + + async def get_emotion(self) -> Dict[str, Any]: + """ + 获取情绪随时间变化数据 + + Returns: + 包含情绪数据的字典 + """ + try: + logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}") + + if self.table == 'Statement': + results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id) + elif self.table == 'ExtractedEntity': + results = await self.connector.execute_query(Memory_Space_Emotion_ExtractedEntity, id=self.id) + else: + # MemorySummary/Chunk类型查询 + results = await self.connector.execute_query(Memory_Space_Emotion_MemorySummary, id=self.id) + + # 处理查询结果 + emotion_data = self._process_emotion_results(results) + + # 转换Neo4j类型 + final_data = self._convert_neo4j_types(emotion_data) + + logger.info(f"成功获取 {len(final_data)} 条情绪数据") + + return final_data + + except Exception as e: + logger.error(f"获取情绪数据失败: {str(e)}") + return e + + def _process_emotion_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 处理情绪查询结果,按emotion_type和created_at分组并累加emotion_intensity + + Args: + results: Neo4j查询结果 + + Returns: + 处理后的情绪数据列表,相同emotion_type和created_at的记录会合并并累加intensity + """ + length_data=[] + from collections import defaultdict + + # 用于按(emotion_type, created_at)分组累加intensity + emotion_groups = defaultdict(float) + + # 检查results是否为空或不是列表 + if not results or not isinstance(results, list): + 
logger.warning(f"情绪查询结果为空或格式不正确: {type(results)}") + return [] + + for record in results: + # 检查record是否为字典类型 + if not isinstance(record, dict): + logger.warning(f"跳过非字典类型的记录: {type(record)} - {record}") + continue + + # 获取创建时间并格式化 + created_at = record.get('created_at') + formatted_created_at = created_at + + # 如果created_at是字符串格式,尝试格式化 + if isinstance(created_at, str): + formatted_created_at = self._format_datetime(created_at) + + emotion_type = record.get('emotion_type') + emotion_intensity = record.get('emotion_intensity') + if emotion_type !=None: + length_data.append(emotion_intensity) + + + if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None: + # 使用(emotion_type, created_at)作为分组键 + group_key = (emotion_type, formatted_created_at) + + # 累加emotion_intensity + try: + emotion_groups[group_key] += float(emotion_intensity) + except (ValueError, TypeError): + logger.warning(f"无法转换emotion_intensity为数字: {emotion_intensity}") + continue + # 转换为最终格式 + emotion_data = [ + { + 'emotion_intensity': round(intensity / len(length_data) * 100, 2), + 'emotion_type': emotion_type, + 'created_at': created_at + } + for (emotion_type, created_at), intensity in emotion_groups.items() + ] + + # 按时间排序(最新的在前) + emotion_data.sort(key=lambda x: x.get('created_at', ''), reverse=True) + + + return emotion_data + + async def close(self): + """关闭数据库连接""" + await self.connector.close() + + +class MemoryInteraction: + def __init__(self, id: str, table: str): + self.id = id + self.table = table + self.connector = Neo4jConnector() + + def _convert_neo4j_types(self, obj: Any) -> Any: + """ + 递归转换Neo4j特殊类型为可序列化的Python类型 + """ + if isinstance(obj, Neo4jDateTime): + # 转换为用户友好的日期格式 + return self._format_datetime(obj.iso_format()) + elif hasattr(obj, '__class__') and 'neo4j' in str(obj.__class__): + if hasattr(obj, 'iso_format'): + return self._format_datetime(obj.iso_format()) + elif hasattr(obj, '__str__'): + return str(obj) + else: + return repr(obj) + 
elif isinstance(obj, dict): + return {k: self._convert_neo4j_types(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._convert_neo4j_types(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(self._convert_neo4j_types(item) for item in obj) + else: + return obj + + def _format_datetime(self, iso_string: str) -> str: + """ + 将ISO格式的日期时间字符串转换为用户友好的格式 + + Args: + iso_string: ISO格式的日期时间字符串,如 "2026-01-07T13:40:33.679530" + + Returns: + 格式化后的日期时间字符串,如 "2026-01-07 13:40:33" + """ + try: + # 解析ISO格式的日期时间 + dt = datetime.fromisoformat(iso_string.replace('Z', '+00:00')) + # 返回用户友好的格式:YYYY-MM-DD HH:MM:SS + return dt.strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, AttributeError): + # 如果解析失败,返回原始字符串 + return iso_string + + + async def get_interaction_frequency(self) -> Dict[str, Any]: + """ + 获取交互频率数据 + + Returns: + 包含交互数据的字典 + """ + try: + logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}") + + ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id) + if ori_data!=[]: + # name = ori_data[0]['name'] + group_id = ori_data[0]['group_id'] + Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id) + if not Space_User: + return '不存在用户' + user_id=Space_User[0]['id'] + + results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id) + + + + # 处理查询结果 + interaction_data = self._process_interaction_results(results) + + # 转换Neo4j类型 + final_data = self._convert_neo4j_types(interaction_data) + + logger.info(f"成功获取 {len(final_data)} 条交互数据") + + return final_data + return [] + + except Exception as e: + logger.error(f"获取交互数据失败: {str(e)}") + return e + + def _process_interaction_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 处理交互查询结果,按季度统计交互频率 + + Args: + results: Neo4j查询结果 + + Returns: + 按季度统计的交互数据列表,格式: [{"created_at": "2026Q1", "count": 3}] + """ + from collections import defaultdict + from datetime import datetime + + # 
用于按季度分组计数 + quarterly_counts = defaultdict(int) + + for record in results: + # 过滤掉statement为None的记录 + if not isinstance(record, dict) or record.get('statement') is None: + continue + + created_at = record.get('created_at') + if not created_at: + continue + + try: + # 处理不同类型的时间格式 + if isinstance(created_at, str): + # 解析ISO格式时间字符串 + dt = datetime.fromisoformat(created_at.replace('Z', '+00:00')) + elif hasattr(created_at, 'year') and hasattr(created_at, 'month'): + # 处理Neo4j DateTime对象 + dt = datetime(created_at.year, created_at.month, created_at.day) + else: + continue + # 计算季度 + quarter = (dt.month - 1) // 3 + 1 + quarter_key = f"{dt.year}.Q{quarter}" + # 增加该季度的计数 + quarterly_counts[quarter_key] += 1 + + except (ValueError, AttributeError) as e: + logger.warning(f"解析时间失败: {e}, 原始值: {created_at}") + continue + + # 转换为所需格式并按时间排序 + interaction_data = [ + {"created_at": quarter, "count": count} + for quarter, count in quarterly_counts.items() + ] + + # 按季度排序(最新的在前) + interaction_data.sort(key=lambda x: x["created_at"], reverse=True) + + return interaction_data + + async def close(self): + """关闭数据库连接""" + await self.connector.close() diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 30a84b25..8979682d 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -11,7 +11,7 @@ """ from typing import Optional, Dict, Any, Tuple -from datetime import datetime +from datetime import datetime, timezone from sqlalchemy.orm import Session @@ -24,18 +24,54 @@ from app.core.memory.storage_services.forgetting_engine.config_utils import ( ) from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.forgetting_cycle_history_repository import ForgettingCycleHistoryRepository # 获取API专用日志器 api_logger = get_api_logger() +def convert_neo4j_datetime_to_python(value: Any) -> 
async def _get_pending_forgetting_nodes(
    self,
    connector: Neo4jConnector,
    group_id: str,
    forgetting_threshold: float,
    min_days_since_access: int,
    limit: int = 20
) -> list[Dict[str, Any]]:
    """
    List nodes that currently qualify for forgetting.

    A node qualifies when it is a Statement, ExtractedEntity or
    MemorySummary of the given group, its ``activation_value`` is below
    ``forgetting_threshold`` and its ``last_access_time`` is older than
    ``min_days_since_access`` days (measured from the current UTC time).
    Rows come back ordered by ascending activation value.

    Args:
        connector: Neo4j connector used to run the query.
        group_id: Group ID the nodes must belong to.
        forgetting_threshold: Upper bound (exclusive) for activation value.
        min_days_since_access: Minimum age of the last access, in days.
        limit: Maximum number of nodes to return.

    Returns:
        A list of dicts with ``node_id`` (str), ``node_type`` (lower-cased
        label, with ``extractedentity`` -> ``entity`` and
        ``memorysummary`` -> ``summary``), ``content_summary``,
        ``activation_value`` and ``last_access_time`` (integer epoch
        seconds, ``0`` when the time could not be converted).
    """
    from datetime import timedelta

    # Cut-off for "not accessed recently", as an ISO-8601 UTC string.
    cutoff = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
    cutoff_str = cutoff.strftime('%Y-%m-%dT%H:%M:%S.%fZ')

    query = """
    MATCH (n)
    WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
      AND n.group_id = $group_id
      AND n.activation_value IS NOT NULL
      AND n.activation_value < $threshold
      AND n.last_access_time IS NOT NULL
      AND datetime(n.last_access_time) < datetime($min_access_time_str)
    RETURN
        elementId(n) as node_id,
        labels(n)[0] as node_type,
        CASE
            WHEN n:Statement THEN n.statement
            WHEN n:ExtractedEntity THEN n.name
            WHEN n:MemorySummary THEN n.content
            ELSE ''
        END as content_summary,
        n.activation_value as activation_value,
        n.last_access_time as last_access_time
    ORDER BY n.activation_value ASC
    LIMIT $limit
    """

    rows = await connector.execute_query(
        query,
        group_id=group_id,
        threshold=forgetting_threshold,
        min_access_time_str=cutoff_str,
        limit=limit,
    )

    # Short public names for the lower-cased labels; others pass through.
    short_names = {'extractedentity': 'entity', 'memorysummary': 'summary'}

    pending_nodes = []
    for row in rows:
        label = row['node_type'].lower()
        node_type = short_names.get(label, label)

        accessed_at = convert_neo4j_datetime_to_python(row['last_access_time'])
        if not accessed_at:
            accessed_ts = 0
        else:
            # A naive datetime is assumed to be UTC so that timestamp()
            # does not apply the host's local offset.
            if accessed_at.tzinfo is None:
                accessed_at = accessed_at.replace(tzinfo=timezone.utc)
            accessed_ts = int(accessed_at.timestamp())

        pending_nodes.append({
            'node_id': str(row['node_id']),
            'node_type': node_type,
            'content_summary': row['content_summary'] or '',
            'activation_value': row['activation_value'],
            'last_access_time': accessed_ts
        })

    return pending_nodes
config_id: 配置ID(可选) + config_id: 配置ID(必填,由控制器层通过 group_id 获取) Returns: dict: 遗忘报告 @@ -187,6 +314,9 @@ class MemoryForgetService: # 获取遗忘引擎组件 _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id) + # 记录执行开始时间 + execution_time = datetime.now() + # 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取) report = await forgetting_scheduler.run_forgetting_cycle( group_id=group_id, @@ -202,6 +332,58 @@ class MemoryForgetService: f"耗时 {report['duration_seconds']:.2f} 秒" ) + # 获取当前的激活值统计(用于记录历史) + try: + connector = forgetting_scheduler.connector + stats_query = """ + MATCH (n) + WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) + AND n.group_id = $group_id + RETURN + count(n) as total_nodes, + avg(n.activation_value) as average_activation, + sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes + """ + + stats_results = await connector.execute_query( + stats_query, + group_id=group_id, + threshold=config['forgetting_threshold'] + ) + + if stats_results: + stats = stats_results[0] + total_nodes = stats['total_nodes'] or 0 + average_activation = stats['average_activation'] + low_activation_nodes = stats['low_activation_nodes'] or 0 + else: + total_nodes = 0 + average_activation = None + low_activation_nodes = 0 + + # 保存历史记录到数据库 + self.history_repository.create( + db=db, + end_user_id=group_id, + execution_time=execution_time, + merged_count=report['merged_count'], + failed_count=report['failed_count'], + average_activation_value=average_activation, + total_nodes=total_nodes, + low_activation_nodes=low_activation_nodes, + duration_seconds=report['duration_seconds'], + trigger_type='manual' + ) + + api_logger.info( + f"已保存遗忘周期历史记录: end_user_id={group_id}, " + f"merged_count={report['merged_count']}" + ) + + except Exception as e: + # 记录历史失败不应影响主流程 + api_logger.error(f"保存遗忘周期历史记录失败: {str(e)}") + return report def read_forgetting_config( @@ -337,7 +519,8 @@ class 
MemoryForgetService: 'nodes_without_activation': result['nodes_without_activation'] or 0, 'average_activation_value': result['average_activation'], 'low_activation_nodes': result['low_activation_nodes'] or 0, - 'timestamp': datetime.now().isoformat() + 'forgetting_threshold': forgetting_threshold, + 'timestamp': int(datetime.now().timestamp()) } else: activation_metrics = { @@ -346,7 +529,8 @@ class MemoryForgetService: 'nodes_without_activation': 0, 'average_activation_value': None, 'low_activation_nodes': 0, - 'timestamp': datetime.now().isoformat() + 'forgetting_threshold': forgetting_threshold, + 'timestamp': int(datetime.now().timestamp()) } # 收集节点类型分布 @@ -395,19 +579,95 @@ class MemoryForgetService: 'chunk_count': 0 } - # 构建统计信息(不包含监控历史数据) + # 获取最近7个日期的历史趋势数据(每天取最后一次执行) + recent_trends = [] + try: + if group_id: + # 查询所有历史记录 + history_records = self.history_repository.get_recent_by_end_user( + db=db, + end_user_id=group_id + ) + + # 按日期分组(一天可能有多次执行,取最后一次) + from collections import OrderedDict + daily_records = OrderedDict() + + # 遍历记录(已按时间降序),每个日期只保留第一次遇到的(即最后一次执行) + for record in history_records: + # 提取日期(格式: "1/1", "1/2")- 跨平台兼容 + month = record.execution_time.month + day = record.execution_time.day + date_str = f"{month}/{day}" + + # 如果这个日期还没有记录,添加它(这是该日期最后一次执行) + if date_str not in daily_records: + daily_records[date_str] = record + + # 如果已经有7个不同的日期,停止 + if len(daily_records) >= 7: + break + + # 构建趋势数据点(按时间从旧到新排序) + sorted_dates = sorted( + daily_records.items(), + key=lambda x: x[1].execution_time + ) + + for date_str, record in sorted_dates: + recent_trends.append({ + 'date': date_str, + 'merged_count': record.merged_count, + 'average_activation': record.average_activation_value, + 'total_nodes': record.total_nodes, + 'execution_time': int(record.execution_time.timestamp()) + }) + + api_logger.info(f"成功获取最近 {len(recent_trends)} 个日期的历史趋势数据") + + except Exception as e: + api_logger.error(f"获取历史趋势数据失败: {str(e)}") + # 失败时返回空列表,不影响主流程 + + # 
获取待遗忘节点列表(前20个满足遗忘条件的节点) + pending_nodes = [] + try: + if group_id: + # 验证 min_days_since_access 配置值 + min_days = config.get('min_days_since_access') + if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0: + api_logger.warning( + f"min_days_since_access 配置无效: {min_days}, 使用默认值 7" + ) + min_days = 7 + + pending_nodes = await self._get_pending_forgetting_nodes( + connector=connector, + group_id=group_id, + forgetting_threshold=forgetting_threshold, + min_days_since_access=int(min_days), + limit=20 + ) + + api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点") + + except Exception as e: + api_logger.error(f"获取待遗忘节点失败: {str(e)}") + # 失败时返回空列表,不影响主流程 + + # 构建统计信息 stats = { 'activation_metrics': activation_metrics, 'node_distribution': node_distribution, - 'consistency_check': None, # 不再提供一致性检查 - 'nodes_merged_total': 0, # 不再跟踪累计融合数 - 'recent_cycles': [], # 不再提供历史记录 - 'timestamp': datetime.now().isoformat() + 'recent_trends': recent_trends, + 'pending_nodes': pending_nodes, + 'timestamp': int(datetime.now().timestamp()) } api_logger.info( f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, " - f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}" + f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, " + f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}" ) return stats diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py new file mode 100644 index 00000000..5fafe48d --- /dev/null +++ b/api/app/services/memory_perceptual_service.py @@ -0,0 +1,167 @@ +import uuid +from typing import Dict, Any, Optional + +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.logging_config import get_business_logger +from app.models.memory_perceptual_model import PerceptualType, FileStorageType +from 
app.repositories.memory_perceptual_repository import MemoryPerceptualRepository +from app.schemas.memory_perceptual_schema import ( + PerceptualQuerySchema, + PerceptualTimelineResponse, + PerceptualMemoryItem, + AudioModal, Content, VideoModal, TextModal +) + +business_logger = get_business_logger() + + +class MemoryPerceptualService: + def __init__(self, db: Session): + self.db = db + self.repository = MemoryPerceptualRepository(db) + + def get_memory_count(self, end_user_id: uuid.UUID) -> Dict[str, Any]: + """Retrieve perceptual memory statistics for a user.""" + business_logger.info(f"Fetching perceptual memory statistics: end_user_id={end_user_id}") + try: + total_count = self.repository.get_count_by_user_id(end_user_id=end_user_id) + + vision_count = self.repository.get_count_by_type(end_user_id, PerceptualType.VISION) + audio_count = self.repository.get_count_by_type(end_user_id, PerceptualType.AUDIO) + text_count = self.repository.get_count_by_type(end_user_id, PerceptualType.TEXT) + conversation_count = self.repository.get_count_by_type(end_user_id, PerceptualType.CONVERSATION) + + stats = { + "total": total_count, + "by_type": { + "vision": vision_count, + "audio": audio_count, + "text": text_count, + "conversation": conversation_count + } + } + + business_logger.info(f"Memory statistics fetched successfully: total={total_count}") + return stats + + except Exception as e: + business_logger.error(f"Failed to fetch memory statistics: {str(e)}") + raise BusinessException(f"Failed to fetch memory statistics: {str(e)}", BizCode.DB_ERROR) + + def _get_latest_memory_by_type( + self, + end_user_id: uuid.UUID, + perceptual_type: PerceptualType + ) -> Optional[dict[str, Any]]: + """Internal helper to retrieve the latest memory by type.""" + business_logger.info(f"Fetching latest {perceptual_type.name.lower()} memory: end_user_id={end_user_id}") + try: + memories = self.repository.get_by_type( + end_user_id=end_user_id, + perceptual_type=perceptual_type, + limit=1, 
+ offset=0 + ) + if not memories: + business_logger.info(f"No {perceptual_type.name.lower()} memory found: end_user_id={end_user_id}") + return None + + memory = memories[0] + meta_data = memory.meta_data or {} + modalities = meta_data.get("modalities") + content = meta_data.get("content") + + if not modalities: + raise BusinessException(f"Modalities not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR) + if not content: + raise BusinessException(f"Content not defined, perceptual memory_id={memory.id}", BizCode.DB_ERROR) + content = Content(**content) + match perceptual_type: + case PerceptualType.VISION: + modal = VideoModal(**modalities) + case PerceptualType.AUDIO: + modal = AudioModal(**modalities) + case PerceptualType.TEXT: + modal = TextModal(**modalities) + case _: + raise BusinessException("Unsupported perceptual type", BizCode.DB_ERROR) + detail = modal.model_dump() + + result = { + "id": str(memory.id), + "file_name": memory.file_name, + "file_path": memory.file_path, + "storage_type": memory.storage_service, + "summary": memory.summary, + "keywords": content.keywords, + "topic": content.topic, + "domain": content.domain, + "created_time": int(memory.created_time.timestamp()*1000), + **detail + } + + business_logger.info( + f"Latest {perceptual_type.name.lower()} memory retrieved successfully: file={memory.file_name}") + return result + + except Exception as e: + business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}") + raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}", + BizCode.DB_ERROR) + + def get_latest_visual_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]: + return self._get_latest_memory_by_type(end_user_id, PerceptualType.VISION) + + def get_latest_audio_memory(self, end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]: + return self._get_latest_memory_by_type(end_user_id, PerceptualType.AUDIO) + + def get_latest_text_memory(self, 
end_user_id: uuid.UUID) -> Optional[Dict[str, Any]]: + return self._get_latest_memory_by_type(end_user_id, PerceptualType.TEXT) + + def get_time_line(self, end_user_id: uuid.UUID, query: PerceptualQuerySchema) -> PerceptualTimelineResponse: + """Retrieve a timeline of perceptual memories for a user.""" + business_logger.info(f"Fetching perceptual memory timeline: " + f"end_user_id={end_user_id}, filter={query.filter}") + + try: + if query.page < 1: + raise BusinessException("Page number must be greater than 0", BizCode.INVALID_PARAMETER) + if query.page_size < 1 or query.page_size > 100: + raise BusinessException("Page size must be between 1 and 100", BizCode.INVALID_PARAMETER) + + total_count, memories = self.repository.get_timeline(end_user_id, query) + + memory_items = [] + for memory in memories: + memory_item = PerceptualMemoryItem( + id=memory.id, + perceptual_type=PerceptualType(memory.perceptual_type), + file_path=memory.file_path, + file_name=memory.file_name, + file_ext=memory.file_ext, + summary=memory.summary, + created_time=int(memory.created_time.timestamp()*1000), + storage_type=FileStorageType(memory.storage_service), + ) + memory_items.append(memory_item) + + timeline_response = PerceptualTimelineResponse( + total=total_count, + page=query.page, + page_size=query.page_size, + total_pages=(total_count + query.page_size - 1) // query.page_size, + memories=memory_items + ) + + business_logger.info(f"Perceptual memory timeline retrieved successfully: " + f"total={total_count}, returned={len(memories)}") + return timeline_response + + except BusinessException: + raise + except Exception as e: + business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") + raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) diff --git a/api/app/services/memory_short_service.py b/api/app/services/memory_short_service.py new file mode 100644 index 00000000..fa3870f0 --- /dev/null +++ 
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.repositories.memory_short_repository import ShortTermMemoryRepository


api_logger = get_api_logger()
# NOTE(review): this session is created once at import time and shared by every
# service instance that is not given its own session. get_db is presumably a
# generator dependency; next() leaves it suspended, so any cleanup it performs
# after its yield would never run - confirm against app.db. Kept as the
# default so existing callers and importers of ``db`` keep working.
db = next(get_db())


def _flatten_retrieved(retrieved_content, extra=None):
    """Flatten a list of ``{query: retrieval}`` dicts into row dicts.

    Non-dict items are skipped. ``extra`` (if given) is appended to every row
    after the ``query`` and ``retrieval`` keys.
    """
    extra = extra or {}
    return [
        {"query": key, "retrieval": values, **extra}
        for item in (retrieved_content or [])
        if isinstance(item, dict)
        for key, values in item.items()
    ]


class ShortService:
    """Read access to a user's short-term memory records."""

    def __init__(self, end_user_id, session=None):
        """Bind the service to a user.

        Args:
            end_user_id: User whose short-term memories are read.
            session: Optional database session; when omitted the module-level
                shared session is used (the original behaviour).
        """
        self.short_repo = ShortTermMemoryRepository(session if session is not None else db)
        self.end_user_id = end_user_id

    def get_short_databasets(self):
        """Return the user's latest short-term memories as plain dicts.

        Requests the 3 most recent records from the repository. Each result
        has the keys ``retrieval`` (flattened retrieved content tagged with its
        source), ``message`` and ``answer``.
        """
        short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
        short_result = []
        for memory in short_memories:
            retrieved_content = memory.retrieved_content or []

            api_logger.debug(f"Retrieved content: {retrieved_content}")

            short_result.append({
                'retrieval': _flatten_retrieved(retrieved_content, {"source": "上下文记忆"}),
                'message': memory.messages,
                'answer': memory.aimessages,
            })
        return short_result

    def get_short_count(self):
        """Return the repository's count of short-term memories for the user."""
        return self.short_repo.count_by_user_id(self.end_user_id)


class LongService:
    """Read access to a user's long-term memory records."""

    def __init__(self, end_user_id, session=None):
        """Bind the service to a user.

        Args:
            end_user_id: User whose long-term memories are read.
            session: Optional database session; when omitted the module-level
                shared session is used (the original behaviour).
        """
        self.long_repo = LongTermMemoryRepository(session if session is not None else db)
        self.end_user_id = end_user_id

    def get_long_databasets(self):
        """Return the user's long-term retrieved content as ``query``/``retrieval`` rows."""
        long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)

        long_result = []
        for long_memory in long_memories:
            long_result.extend(_flatten_retrieved(long_memory.retrieved_content))
        return long_result
a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 55d96082..9cac26ec 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -89,11 +89,15 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) value = item[field] dt = None - # 如果是 datetime 对象,直接使用 - if isinstance(value, datetime): + # 处理不同类型的时间值 + if hasattr(value, 'to_native'): + # Neo4j DateTime 对象 + dt = value.to_native() + elif isinstance(value, datetime): + # Python datetime 对象 dt = value - # 如果是字符串,先解析 elif isinstance(value, str): + # 字符串格式 try: dt = datetime.fromisoformat(value.replace('Z', '+00:00')) except Exception: @@ -185,7 +189,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, - "llm": config.llm, "enable_llm_dedup_blockwise": config.enable_llm_dedup_blockwise, "enable_llm_disambiguation": config.enable_llm_disambiguation, "deep_retrieval": config.deep_retrieval, diff --git a/api/app/services/multi_agent_handoffs_integration.py b/api/app/services/multi_agent_handoffs_integration.py index d35252e8..38fd681a 100644 --- a/api/app/services/multi_agent_handoffs_integration.py +++ b/api/app/services/multi_agent_handoffs_integration.py @@ -13,16 +13,17 @@ from app.schemas.multi_agent_schema import MultiAgentRunRequest from app.core.logging_config import get_business_logger from app.core.exceptions import BusinessException from app.core.error_codes import BizCode +from app.services.multi_agent_service import MultiAgentService logger = get_business_logger() class MultiAgentHandoffsService: """Multi-Agent Handoffs 服务 - 扩展现有的 Multi-Agent Service""" - - def __init__(self, db: Session, multi_agent_service): + + def __init__(self, db: Session, multi_agent_service:MultiAgentService): """初始化服务 - + Args: db: 数据库会话 multi_agent_service: 现有的 MultiAgentService 实例 @@ -30,25 +31,25 @@ class MultiAgentHandoffsService: self.db = db 
self.multi_agent_service = multi_agent_service self.handoff_manager = get_handoff_manager() - + logger.info("Multi-Agent Handoffs 服务初始化完成") - + async def run_with_handoffs( self, app_id: uuid.UUID, request: MultiAgentRunRequest ) -> Dict[str, Any]: """运行支持 handoffs 的多 Agent 任务 - + Args: app_id: 应用 ID request: 运行请求 - + Returns: 执行结果 """ start_time = time.time() - + try: # 1. 获取配置 config = self.multi_agent_service.get_config(app_id) @@ -57,23 +58,25 @@ class MultiAgentHandoffsService: "多 Agent 配置不存在", BizCode.RESOURCE_NOT_FOUND ) - + # 2. 检查是否启用 handoffs execution_config = config.execution_config or {} + print("="*50) + print(execution_config) enable_handoffs = execution_config.get("enable_handoffs", False) - + if not enable_handoffs: # 降级到普通模式 logger.info("Handoffs 未启用,使用普通模式") return await self.multi_agent_service.run(app_id, request) - + # 3. 创建协作编排器 orchestrator = CollaborativeOrchestrator( db=self.db, config=config, handoff_manager=self.handoff_manager ) - + # 4. 执行协作 result = await orchestrator.execute_with_handoffs( message=request.message, @@ -81,11 +84,11 @@ class MultiAgentHandoffsService: user_id=request.user_id, variables=request.variables ) - + # 5. 增强结果 result["mode"] = "handoffs" result["elapsed_time"] = time.time() - start_time - + logger.info( "Handoffs 执行完成", extra={ @@ -95,27 +98,27 @@ class MultiAgentHandoffsService: "elapsed_time": result["elapsed_time"] } ) - + return result - + except Exception as e: logger.error(f"Handoffs 执行失败: {str(e)}") - + # 降级到普通模式 logger.info("降级到普通模式") return await self.multi_agent_service.run(app_id, request) - + async def run_stream_with_handoffs( self, app_id: uuid.UUID, request: MultiAgentRunRequest ) -> AsyncGenerator[str, None]: """流式运行支持 handoffs 的多 Agent 任务 - + Args: app_id: 应用 ID request: 运行请求 - + Yields: SSE 格式的事件流 """ @@ -125,24 +128,24 @@ class MultiAgentHandoffsService: if not config: yield f"data: {{\"event\": \"error\", \"error\": \"配置不存在\"}}\n\n" return - + # 2. 
检查是否启用 handoffs execution_config = config.execution_config or {} enable_handoffs = execution_config.get("enable_handoffs", False) - + if not enable_handoffs: # 降级到普通流式模式 async for event in self.multi_agent_service.run_stream(app_id, request): yield event return - + # 3. 创建协作编排器 orchestrator = CollaborativeOrchestrator( db=self.db, config=config, handoff_manager=self.handoff_manager ) - + # 4. 流式执行 async for event in orchestrator.execute_stream_with_handoffs( message=request.message, @@ -151,27 +154,27 @@ class MultiAgentHandoffsService: variables=request.variables ): yield event - + except Exception as e: logger.error(f"流式 Handoffs 执行失败: {str(e)}") yield f"data: {{\"event\": \"error\", \"error\": \"{str(e)}\"}}\n\n" - + def get_handoff_history( self, conversation_id: str ) -> Optional[Dict[str, Any]]: """获取会话的 handoff 历史 - + Args: conversation_id: 会话 ID - + Returns: Handoff 历史信息 """ state = self.handoff_manager.get_state(conversation_id) if not state: return None - + return { "conversation_id": state.conversation_id, "current_agent_id": state.current_agent_id, @@ -190,27 +193,27 @@ class MultiAgentHandoffsService: "created_at": state.created_at.isoformat(), "updated_at": state.updated_at.isoformat() } - + def clear_handoff_state(self, conversation_id: str): """清除会话的 handoff 状态 - + Args: conversation_id: 会话 ID """ self.handoff_manager.clear_state(conversation_id) logger.info(f"清除 handoff 状态: {conversation_id}") - + async def test_handoff_routing( self, app_id: uuid.UUID, message: str ) -> Dict[str, Any]: """测试 handoff 路由决策(不实际执行) - + Args: app_id: 应用 ID message: 测试消息 - + Returns: 路由决策结果 """ @@ -221,7 +224,7 @@ class MultiAgentHandoffsService: "多 Agent 配置不存在", BizCode.RESOURCE_NOT_FOUND ) - + # 2. 解析 sub agents sub_agents = {} for agent_data in config.sub_agents: @@ -230,37 +233,37 @@ class MultiAgentHandoffsService: sub_agents[str(agent_id)] = { "info": agent_data } - + # 3. 
测试路由 test_conversation_id = f"test-{uuid.uuid4()}" - + # 选择初始 Agent initial_agent_id = None message_lower = message.lower() - + for agent_id, agent_data in sub_agents.items(): agent_info = agent_data.get("info", {}) capabilities = agent_info.get("capabilities", []) role = agent_info.get("role", "") - + keywords = capabilities + ([role] if role else []) for keyword in keywords: if keyword.lower() in message_lower: initial_agent_id = agent_id break - + if initial_agent_id: break - + if not initial_agent_id: initial_agent_id = next(iter(sub_agents.keys())) - + # 4. 生成 handoff 工具 handoff_tools = self.handoff_manager.generate_handoff_tools( initial_agent_id, sub_agents ) - + # 5. 检查是否需要 handoff handoff_suggestion = self.handoff_manager.should_handoff( conversation_id=test_conversation_id, @@ -268,7 +271,7 @@ class MultiAgentHandoffsService: message=message, available_agents=sub_agents ) - + return { "message": message, "initial_agent_id": initial_agent_id, diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 85eaaad2..bb788641 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -31,6 +31,10 @@ class MultiAgentOrchestrator: self.config = config self.registry = AgentRegistry(db) + # 兼容处理:旧的 orchestration_mode 值映射到新值 + # collaboration | supervisor 是新值,其他旧值默认使用 supervisor + self._normalized_mode = self._normalize_orchestration_mode(config.orchestration_mode) + # 加载主 Agent # self.master_agent = self._load_agent(config.master_agent_id) # self. 
config.d @@ -50,33 +54,53 @@ class MultiAgentOrchestrator: # 初始化会话状态管理器 self.state_manager = ConversationStateManager() - # 获取 Master Agent 的模型配置 - if not self.default_model_config_id: - raise BusinessException("Master Agent 缺少模型配置", BizCode.AGENT_CONFIG_MISSING) + # 只有 supervisor 模式才需要 default_model_config_id 和 router + self.master_model_config = None + self.router = None + + if self._normalized_mode == OrchestrationMode.SUPERVISOR: + # 获取 Master Agent 的模型配置 + if not self.default_model_config_id: + raise BusinessException("Supervisor 模式需要配置默认模型", BizCode.AGENT_CONFIG_MISSING) - self.master_model_config = self.db.get(ModelConfig, self.default_model_config_id) - if not self.master_model_config: - raise BusinessException("Master Agent 模型配置不存在", BizCode.AGENT_CONFIG_MISSING) + self.master_model_config = self.db.get(ModelConfig, self.default_model_config_id) + if not self.master_model_config: + raise BusinessException("Master Agent 模型配置不存在", BizCode.AGENT_CONFIG_MISSING) - # 初始化 Master Agent 路由器 - self.router = MasterAgentRouter( - db=db, - master_model_config=self.master_model_config, - model_parameters=self.model_parameters, - sub_agents=self.sub_agents, - state_manager=self.state_manager, - enable_rule_fast_path=config.execution_config.get("enable_rule_fast_path", True) - ) + # 初始化 Master Agent 路由器 + self.router = MasterAgentRouter( + db=db, + master_model_config=self.master_model_config, + model_parameters=self.model_parameters, + sub_agents=self.sub_agents, + state_manager=self.state_manager, + enable_rule_fast_path=config.execution_config.get("enable_rule_fast_path", True) + ) logger.info( "多 Agent 编排器初始化完成", extra={ "config_id": str(config.id), - "model": self.master_model_config.name, - "sub_agent_count": len(self.sub_agents) + "model": self.master_model_config.name if self.master_model_config else None, + "sub_agent_count": len(self.sub_agents), + "orchestration_mode": self._normalized_mode } ) + def _normalize_orchestration_mode(self, mode: str) -> str: + 
"""标准化 orchestration_mode,兼容旧值 + + Args: + mode: 原始的 orchestration_mode 值 + + Returns: + 标准化后的模式:collaboration 或 supervisor + """ + if mode in [OrchestrationMode.SUPERVISOR, "supervisor"]: + return OrchestrationMode.SUPERVISOR + # 其他所有值(包括旧的 sequential、parallel、conditional、loop 和 collaboration)都映射到 collaboration + return OrchestrationMode.COLLABORATION + async def execute_stream( self, message: str, @@ -108,7 +132,7 @@ class MultiAgentOrchestrator: logger.info( "开始执行多 Agent 任务(流式)", extra={ - "mode": self.config.orchestration_mode, + "mode": self._normalized_mode, "message_length": len(message) } ) @@ -116,16 +140,29 @@ class MultiAgentOrchestrator: try: # 发送开始事件 yield self._format_sse_event("start", { - "mode": self.config.orchestration_mode, + "mode": self._normalized_mode, "timestamp": time.time() }) - # 1. 主 Agent 分析任务 - task_analysis = await self._analyze_task(message, variables) - task_analysis["use_llm_routing"] = use_llm_routing - # 2. 根据模式执行(流式) - if self.config.orchestration_mode == OrchestrationMode.CONDITIONAL: + # Collaboration 模式:Agent 之间可以相互 handoff(使用 handoffs_service) + if self._normalized_mode == OrchestrationMode.COLLABORATION: + async for event in self._execute_collaboration_mode_stream( + message, + conversation_id, + user_id, + web_search, + memory, + storage_type, + user_rag_memory_id + ): + yield event + # Supervisor 模式:由主 Agent 统一调度子 Agent + elif self._normalized_mode == OrchestrationMode.SUPERVISOR: + # 1. 
主 Agent 分析任务 + task_analysis = await self._analyze_task(message, variables) + task_analysis["use_llm_routing"] = use_llm_routing + async for event in self._execute_conditional_stream( task_analysis, conversation_id, @@ -137,62 +174,10 @@ class MultiAgentOrchestrator: ): yield event else: - # 其他模式暂时使用非流式执行,然后一次性返回 - if self.config.orchestration_mode == OrchestrationMode.SEQUENTIAL: - results = await self._execute_sequential( - task_analysis, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - elif self.config.orchestration_mode == OrchestrationMode.PARALLEL: - results = await self._execute_parallel( - task_analysis, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - # elif self.config.orchestration_mode == "loop": - # results = await self._execute_loop( - # task_analysis, - # conversation_id, - # user_id, - # web_search, - # memory, - # storage_type, - # user_rag_memory_id - # ) - else: - raise BusinessException( - f"不支持的编排模式: {self.config.orchestration_mode}", - BizCode.INVALID_PARAMETER - ) - - # 整合结果 - final_result = await self._aggregate_results(results) - - # 提取会话 ID - sub_conversation_id = None - if isinstance(results, dict): - sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") - elif isinstance(results, list) and results: - for item in results: - if "result" in item: - sub_conversation_id = item["result"].get("conversation_id") - if sub_conversation_id: - break - - # 发送消息事件 - yield self._format_sse_event("message", { - "content": final_result, - "conversation_id": sub_conversation_id - }) + raise BusinessException( + f"不支持的编排模式: {self._normalized_mode}", + BizCode.INVALID_PARAMETER + ) elapsed_time = time.time() - start_time @@ -205,7 +190,7 @@ class MultiAgentOrchestrator: logger.info( "多 Agent 任务完成(流式)", extra={ - "mode": self.config.orchestration_mode, + "mode": self._normalized_mode, "elapsed_time": elapsed_time 
} ) @@ -213,7 +198,7 @@ class MultiAgentOrchestrator: except Exception as e: logger.error( "多 Agent 任务执行失败(流式)", - extra={"error": str(e), "mode": self.config.orchestration_mode} + extra={"error": str(e), "mode": self._normalized_mode} ) # 发送错误事件 yield self._format_sse_event("error", { @@ -247,10 +232,23 @@ class MultiAgentOrchestrator: logger.info( "开始执行多 Agent 任务", - extra={"message_length": len(message)} + extra={ + "message_length": len(message), + "mode": self._normalized_mode + } ) try: + # Collaboration 模式:使用 handoffs_service + if self._normalized_mode == OrchestrationMode.COLLABORATION: + return await self._execute_collaboration_mode( + message, + conversation_id, + user_id, + variables + ) + + # Supervisor 模式:由 Master Agent 统一调度 # 1. Master Agent 分析任务并做出决策 task_analysis = await self._analyze_task(message, variables) @@ -1267,7 +1265,8 @@ class MultiAgentOrchestrator: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, web_search=web_search, - memory=memory + memory=memory, + sub_agent=True ): yield event @@ -1292,6 +1291,7 @@ class MultiAgentOrchestrator: conversation_id: 会话 ID user_id: 用户 ID + Returns: 执行结果 """ @@ -1409,6 +1409,148 @@ class MultiAgentOrchestrator: return self._merge_results(results) + async def _execute_collaboration_mode_stream( + self, + message: str, + conversation_id: Optional[uuid.UUID], + user_id: Optional[str], + web_search: bool = False, + memory: bool = True, + storage_type: str = '', + user_rag_memory_id: str = '' + ): + """Collaboration 模式流式执行 - Agent 之间可以相互 handoff + + 使用 handoffs_service 实现 Agent 之间的动态切换 + + Args: + message: 用户消息 + conversation_id: 会话 ID + user_id: 用户 ID + web_search: 是否启用网络搜索 + memory: 是否启用记忆 + storage_type: 存储类型 + user_rag_memory_id: RAG 记忆 ID + + Yields: + SSE 格式的事件流 + """ + from app.services.handoffs_service import ( + convert_multi_agent_config_to_handoffs, + HandoffsService + ) + + try: + # 1. 
构建 multi_agent_config 字典 + multi_agent_config = { + "sub_agents": self.config.sub_agents, + "orchestration_mode": self.config.orchestration_mode + } + + # 2. 转换配置(每个 Agent 包含自己的 model_config) + agent_configs = convert_multi_agent_config_to_handoffs( + multi_agent_config, + self.db + ) + + if not agent_configs: + raise BusinessException("没有可用的子 Agent", BizCode.AGENT_CONFIG_MISSING) + + # 3. 创建 HandoffsService + handoffs_service = HandoffsService( + agent_configs=agent_configs, + streaming=True + ) + + # 4. 使用 handoffs_service 的流式聊天 + conv_id = str(conversation_id) if conversation_id else None + + async for event in handoffs_service.chat_stream( + message=message, + conversation_id=conv_id + ): + # handoffs_service 返回的已经是 SSE 格式,直接 yield + yield event + + except Exception as e: + logger.error(f"Collaboration 模式执行失败: {str(e)}", exc_info=True) + yield self._format_sse_event("error", { + "error": str(e), + "timestamp": time.time() + }) + + async def _execute_collaboration_mode( + self, + message: str, + conversation_id: Optional[uuid.UUID], + user_id: Optional[str], + variables: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Collaboration 模式非流式执行 - Agent 之间可以相互 handoff + + 使用 handoffs_service 实现 Agent 之间的动态切换 + + Args: + message: 用户消息 + conversation_id: 会话 ID + user_id: 用户 ID + variables: 变量参数 + + Returns: + 执行结果 + """ + from app.services.handoffs_service import ( + convert_multi_agent_config_to_handoffs, + HandoffsService + ) + + start_time = time.time() + + try: + # 1. 构建 multi_agent_config 字典 + multi_agent_config = { + "sub_agents": self.config.sub_agents, + "orchestration_mode": self.config.orchestration_mode + } + + # 2. 转换配置(每个 Agent 包含自己的 model_config) + agent_configs = convert_multi_agent_config_to_handoffs( + multi_agent_config, + self.db + ) + + if not agent_configs: + raise BusinessException("没有可用的子 Agent", BizCode.AGENT_CONFIG_MISSING) + + # 3. 
创建 HandoffsService + handoffs_service = HandoffsService( + agent_configs=agent_configs, + streaming=False + ) + + # 4. 使用 handoffs_service 的非流式聊天 + conv_id = str(conversation_id) if conversation_id else None + + result = await handoffs_service.chat( + message=message, + conversation_id=conv_id + ) + + elapsed_time = time.time() - start_time + + return { + "message": result.get("response", ""), + "conversation_id": result.get("conversation_id"), + "elapsed_time": elapsed_time, + "strategy": "collaboration", + "active_agent": result.get("active_agent"), + "sub_results": result + } + + except Exception as e: + logger.error(f"Collaboration 模式执行失败: {str(e)}", exc_info=True) + raise + def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str: """格式化 SSE 事件 diff --git a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index 709df3ab..1a08a5af 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -264,10 +264,11 @@ class MultiAgentService: if not app: raise ResourceNotFoundException("应用", str(app_id)) - # 2. 验证模型配置 - model_api_key = ModelApiKeyService.get_a_api_key(self.db,data.default_model_config_id) - if not model_api_key: - raise ResourceNotFoundException("模型配置", str(data.default_model_config_id)) + # 2. 验证模型配置(如果提供了) + if data.default_model_config_id: + model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id) + if not model_api_key: + raise ResourceNotFoundException("模型配置", str(data.default_model_config_id)) # 3. 
验证子 Agent 存在并获取发布版本 ID for sub_agent in data.sub_agents: @@ -347,13 +348,14 @@ class MultiAgentService: ) return config + # 完全替换配置,但对于数据库 NOT NULL 字段,如果新值是 None 则保留原值 config.default_model_config_id = newConfig.default_model_config_id config.model_parameters = newConfig.model_parameters - config.orchestration_mode = newConfig.orchestration_mode - config.sub_agents = newConfig.sub_agents + config.orchestration_mode = newConfig.orchestration_mode or config.orchestration_mode + config.sub_agents = newConfig.sub_agents if newConfig.sub_agents is not None else config.sub_agents config.routing_rules = newConfig.routing_rules - config.execution_config = newConfig.execution_config - config.aggregation_strategy = newConfig.aggregation_strategy + config.execution_config = newConfig.execution_config if newConfig.execution_config else config.execution_config + config.aggregation_strategy = newConfig.aggregation_strategy or config.aggregation_strategy self.db.commit() self.db.refresh(config) diff --git a/api/app/services/prompt/conversation_summary_system.jinja2 b/api/app/services/prompt/conversation_summary_system.jinja2 new file mode 100644 index 00000000..2947c237 --- /dev/null +++ b/api/app/services/prompt/conversation_summary_system.jinja2 @@ -0,0 +1,50 @@ +{% raw %} +# Role Definition +You are a professional dialogue content summarizer, specializing in extracting core information from multi-turn conversations between users and AI. Your goal is to generate concise, accurate summaries with extended key fields that help users quickly grasp the conversation's theme, key points, and value. + +# Core Rules +- **Mandatory Rules**: + 1. Fully extract explicit user requests (questions/tasks) without omitting key details; + 2. Accurately summarize AI’s core responses (explanations/guidance) aligned with user requests; + 3. Reflect cause-and-effect relationships in multi-turn interactions (follow-up questions, clarifications); + 4. 
Clearly identify and describe the conversation’s theme, key takeaways, and other required extended fields. +- **Constraints**: + 1. Do not add unmentioned information or subjective assumptions; + 2. Avoid vague expressions (e.g., "the user asked some questions"); be specific; + 3. For repetitive content (same question multiple times), keep only the initial request and final response. + +# Input Processing +- Reading Order: Chronological sentence-by-sentence reading; +- Priority: User requests > AI responses > interaction logic > theme/takeaway extraction; +- Exception Handling: If the conversation is empty/invalid (only greetings, no substantive content), output "The conversation content is invalid and a summary cannot be generated." + +# Execution Process +1. **Information Extraction**: + - Input: {{conversation}} + - Operation: Label user requests, AI responses, interaction nodes, conversation theme (core topic), and takeaways (key insights/results) sentence by sentence; +2. **Logic Organization**: + - Input: Labeled extracted information + - Operation: Match requests with responses, organize interaction progression, and associate theme/takeaways with core content; +3. **Summary Generation**: + - Input: Organized logical relationships and extended fields + - Operation: Integrate core information, theme, and takeaways into coherent language, ensuring all key elements are covered while removing redundancy. + +# Output Specifications (JSON Format) +- Language: Please strictly output content in the language specified by the tag. +- Structure: JSON object with five fields: + 1. `theme`: A concise phrase describing the conversation’s core topic (e.g., "inquiry about delivery time rules"); + 2. `summary`: A single sentence including "user request + AI response + interaction logic" (≤150 words); + 3.
`takeaways`: A list of brief bullet-point takeaways summarizing the key points from the conversation (e.g., ["User clarified delivery time differences between regular and remote areas"]). + 4. `question`: A list of brief declarative statements summarizing the pitfalls the user encountered during the current conversation. Return an empty list if none are present. + 5. `info_score`: Numerical score (0–100) representing conversation information richness. +- Language Style: Concise, objective, conversational (avoid overly formal terms). + +# Example JSON Output +{ + "theme": string, + "summary": string, + "takeaways": array[string], + "question": array[string], + "info_score": 85 +} +{% endraw %} diff --git a/api/app/services/prompt/conversation_summary_user.jinja2 b/api/app/services/prompt/conversation_summary_user.jinja2 new file mode 100644 index 00000000..51efe34e --- /dev/null +++ b/api/app/services/prompt/conversation_summary_user.jinja2 @@ -0,0 +1,2 @@ +{{ language }} +{{ conversation }} \ No newline at end of file diff --git a/api/app/templates/prompt/prompt_optimizer_system.jinja2 b/api/app/services/prompt/prompt_optimizer_system.jinja2 similarity index 100% rename from api/app/templates/prompt/prompt_optimizer_system.jinja2 rename to api/app/services/prompt/prompt_optimizer_system.jinja2 diff --git a/api/app/templates/prompt/prompt_optimizer_user.jinja2 b/api/app/services/prompt/prompt_optimizer_user.jinja2 similarity index 100% rename from api/app/templates/prompt/prompt_optimizer_user.jinja2 rename to api/app/services/prompt/prompt_optimizer_user.jinja2 diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 482e8213..c6142c01 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -3,9 +3,8 @@ import uuid from typing import Any, AsyncGenerator import json_repair -from langchain_core.prompts import ChatPromptTemplate -from sqlalchemy.orm import
Session from jinja2 import Template +from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -166,6 +165,8 @@ class PromptOptimizerService: model_config = self.get_model_config(tenant_id, model_id) session_history = self.get_session_message_history(session_id=session_id, user_id=user_id) + logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}") + # Create LLM instance api_config: ModelApiKey = model_config.api_keys[0] llm = RedBearLLM(RedBearModelConfig( @@ -175,11 +176,11 @@ class PromptOptimizerService: base_url=api_config.api_base ), type=ModelType(model_config.type)) try: - with open('app/templates/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f: + with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f: opt_system_prompt = f.read() rendered_system_message = Template(opt_system_prompt).render() - with open('app/templates/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f: + with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f: opt_user_prompt = f.read() except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) @@ -203,7 +204,6 @@ class PromptOptimizerService: messages.extend(session_history[:-1]) # last message is current message messages.extend([(RoleType.USER.value, rendered_user_message)]) - logger.info(f"Prompt optimization message: {messages}") buffer = "" prompt_started = False prompt_finished = False @@ -231,9 +231,9 @@ class PromptOptimizerService: if m: prompt_index = m.start() prompt_finished = True - yield {"type": "delta", "content": buffer[idx:prompt_index]} + yield {"content": buffer[idx:prompt_index]} else: - yield {"type": "delta", "content": cache[idx:]} + yield {"content": cache[idx:]} if len(cache) != 0: idx = len(cache) @@ -249,8 +249,9 @@ class PromptOptimizerService: 
role=RoleType.ASSISTANT, content=desc ) - - yield {"type": "done", "desc": optim_result.get("desc")} + variables = self.parser_prompt_variables(optim_result.get("prompt")) + logger.info(f"Prompt optimization completed, user_id={user_id}, session_id={session_id}") + yield {"desc": optim_result.get("desc"), "variables": variables} @staticmethod def parser_prompt_variables(prompt: str): diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 50cca957..2bb96e53 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -8,7 +8,7 @@ from datetime import datetime from sqlalchemy.orm import Session -from app.core.tools.mcp import MCPClient +from app.core.tools.mcp import MCPToolManager, SimpleMCPClient from app.repositories.tool_repository import ( ToolRepository, BuiltinToolRepository, CustomToolRepository, MCPToolRepository, ToolExecutionRepository @@ -43,6 +43,9 @@ class ToolService: self.db = db self._tool_cache: Dict[str, BaseTool] = {} + # MCP管理器 + self.mcp_tool_manager = MCPToolManager(db) + # 初始化仓储 self.tool_repo = ToolRepository() self.builtin_repo = BuiltinToolRepository() @@ -344,14 +347,16 @@ class ToolService: break if operation_param: - # 有多个操作 + # 有多个操作,为每个操作生成具体参数 methods = [] for operation in operation_param.enum: + # 获取该操作的具体参数 + operation_params = self._get_operation_specific_params(tool_instance, operation) methods.append({ "method_id": f"{config.name}_{operation}", "name": operation, "description": f"{config.description} - {operation}", - "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + "parameters": operation_params }) return methods else: @@ -362,6 +367,243 @@ class ToolService: "description": config.description, "parameters": [p for p in tool_instance.parameters if p.name != "operation"] }] + + def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]: + """获取特定操作的参数列表""" + # 对于datetime_tool,根据操作类型返回相关参数 + if 
hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool': + return self._get_datetime_tool_params(operation) + # 对于json_tool,根据操作类型返回相关参数 + elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool': + return self._get_json_tool_params(operation) + + # 其他工具的默认处理:返回除operation外的所有参数 + return [{ + "name": param.name, + "type": param.type.value, + "description": param.description, + "required": param.required, + "default": param.default, + "enum": param.enum, + "minimum": param.minimum, + "maximum": param.maximum, + "pattern": param.pattern + } for param in tool_instance.parameters if param.name != "operation"] + + def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取datetime_tool特定操作的参数""" + if operation == "now": + return [ + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "format": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "convert_timezone": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + 
"required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + elif operation == "timestamp_to_datetime": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + else: + # 默认返回所有参数(除了operation) + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": False + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "calculation", + "type": "string", + "description": "时间计算表达式(如:+1d, -2h, +30m)", + "required": False + } + ] + + def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取json_tool特定操作的参数""" + base_params = [ + { + "name": "input_data", + "type": "string", + "description": "输入数据(JSON字符串、YAML字符串或XML字符串)", + "required": True + } + ] + + if operation == 
"insert": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "new_value", + "type": "string", + "description": "新值(用于insert操作)", + "required": True + } + ] + elif operation == "replace": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "old_text", + "type": "string", + "description": "要替换的原文本(用于replace操作)", + "required": True + }, + { + "name": "new_text", + "type": "string", + "description": "替换后的新文本(用于replace操作)", + "required": True + } + ] + elif operation == "delete": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + elif operation == "parse": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + + return base_params async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: """获取自定义工具的方法""" @@ -436,23 +678,75 @@ class ToolService: return [] async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: - """获取MCP工具的方法""" + """获取MCP工具的方法和参数""" mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) if not mcp_config: return [] available_tools = mcp_config.available_tools or [] if not available_tools: - return [] + # 如果没有工具列表,尝试同步 + try: + success, tools, _ = await self.mcp_tool_manager.discover_tools( + mcp_config.server_url, mcp_config.connection_config or {} + ) + if success: + # 转换为新格式 + tool_list = [] + for tool in tools: + if tool.get("name"): + tool_list.append({ + tool["name"]: { + "description": tool.get("description", ""), + "inputSchema": tool.get("inputSchema", {}) + } + }) + mcp_config.available_tools = tool_list + self.db.commit() + available_tools = 
tool_list + except Exception as e: + logger.error(f"同步MCP工具列表失败: {e}") + return [] methods = [] - for tool_name in available_tools: - methods.append({ - "method_id": tool_name, - "name": tool_name, - "description": f"MCP工具: {tool_name}", - "parameters": [] # MCP工具参数需要动态获取 - }) + + # 处理新格式的available_tools + for tool_item in available_tools: + if isinstance(tool_item, dict): + for tool_name, tool_data in tool_item.items(): + # 解析工具参数 + parameters = [] + input_schema = tool_data.get("inputSchema", {}) + properties = input_schema.get("properties", {}) + required_fields = input_schema.get("required", []) + + for param_name, param_def in properties.items(): + parameters.append({ + "name": param_name, + "type": param_def.get("type", "string"), + "description": param_def.get("description", ""), + "required": param_name in required_fields, + "default": param_def.get("default"), + "enum": param_def.get("enum"), + "minimum": param_def.get("minimum"), + "maximum": param_def.get("maximum") + }) + + methods.append({ + "method_id": tool_name, + "name": tool_name, + "description": tool_data.get("description", f"MCP工具: {tool_name}"), + "parameters": parameters + }) + else: + # 兼容旧格式(字符串) + tool_name = str(tool_item) + methods.append({ + "method_id": tool_name, + "name": tool_name, + "description": f"MCP工具: {tool_name}", + "parameters": [] + }) return methods @@ -589,10 +883,18 @@ class ToolService: if config.tool_type == ToolType.MCP.value: mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) if mcp_config: + # 处理available_tools显示格式 + available_tools_display = [] + for tool_item in (mcp_config.available_tools or []): + if isinstance(tool_item, dict): + available_tools_display.extend(list(tool_item.keys())) + else: + available_tools_display.append(str(tool_item)) + config_data.update({ "last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None, "health_status": mcp_config.health_status, - "available_tools": 
mcp_config.available_tools or [] + "available_tools": available_tools_display }) return ToolInfo( @@ -832,71 +1134,70 @@ class ToolService: return {} async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]: - """测试MCP连接""" + """测试MCP连接并自动同步工具列表""" try: - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == config.id - ).first() - + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) if not mcp_config: return {"success": False, "message": "MCP配置不存在"} - client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {}) + # 使用集成的MCP管理器测试连接 + test_result = await self.mcp_tool_manager.test_tool_connection( + mcp_config.server_url, mcp_config.connection_config or {} + ) - if await client.connect(): - try: - # tools = await client.list_tools() - await client.disconnect() + if test_result["success"]: + # 连接成功,自动同步工具列表 + success, tools, error = await self.mcp_tool_manager.discover_tools( + mcp_config.server_url, mcp_config.connection_config or {} + ) - # 更新连接状态 + if success: + # 转换为新格式 + tool_list = [] + tool_names = [] + for tool in tools: + if tool.get("name"): + tool_names.append(tool["name"]) + tool_list.append({ + tool["name"]: { + "description": tool.get("description", ""), + "inputSchema": tool.get("inputSchema", {}) + } + }) + + # 更新数据库 + mcp_config.available_tools = tool_list mcp_config.last_health_check = datetime.now() mcp_config.health_status = "healthy" mcp_config.error_message = None + config.status = ToolStatus.AVAILABLE.value - # 更新工具状态 - self._update_tool_status(config) self.db.commit() return { "success": True, - "message": "MCP连接成功", - # "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)} - "details": {"server_url": mcp_config.server_url} + "message": "MCP连接成功并同步工具列表", + "details": { + "server_url": mcp_config.server_url, + "tools_count": len(tool_names), + "tools": tool_names + } } - except Exception as e: - await client.disconnect() - - # 更新错误状态 - 
mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "error" - mcp_config.error_message = str(e) - self._update_tool_status(config) - self.db.commit() - - return {"success": False, "message": f"MCP功能测试失败: {str(e)}"} + else: + return {"success": False, "message": f"同步工具失败: {error}"} else: - # 更新连接失败状态 + # 更新错误状态 mcp_config.last_health_check = datetime.now() mcp_config.health_status = "error" - mcp_config.error_message = "连接失败" - self._update_tool_status(config) + mcp_config.error_message = test_result.get("error", "连接失败") + config.status = ToolStatus.ERROR.value self.db.commit() - return {"success": False, "message": "MCP连接失败"} + return test_result except Exception as e: - # 更新异常状态 - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == config.id - ).first() - if mcp_config: - mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "error" - mcp_config.error_message = str(e) - self._update_tool_status(config) - self.db.commit() - - return {"success": False, "message": f"MCP测试异常: {str(e)}"} + logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}") + return {"success": False, "message": f"测试失败: {str(e)}"} @staticmethod async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]: @@ -951,57 +1252,56 @@ class ToolService: # 创建MCP客户端 connection_config = mcp_config.connection_config or {} - - client = MCPClient(mcp_config.server_url, connection_config) - - if await client.connect(): - try: - # 获取工具列表 - tools = await client.list_tools() - tool_names = [tool.get("name") for tool in tools if tool.get("name")] - - # 更新数据库 - mcp_config.available_tools = tool_names - mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "healthy" - mcp_config.error_message = None - - # 更新工具状态 - config.status = ToolStatus.AVAILABLE.value - - self.db.commit() - - await client.disconnect() - - return { - "success": True, - "message": "工具列表同步成功", - "tools_count": len(tool_names), - 
"tools": tool_names - } - - except Exception as e: - await client.disconnect() - - # 更新错误状态 + client = SimpleMCPClient(mcp_config.server_url, connection_config) + + async with client: + # 获取工具列表 + tools = await client.list_tools() + + # 转换为新格式 + tool_list = [] + tool_names = [] + for tool in tools: + if tool.get("name"): + tool_names.append(tool["name"]) + tool_list.append({ + tool["name"]: { + "description": tool.get("description", ""), + "inputSchema": tool.get("inputSchema", {}) + } + }) + + # 更新数据库 + mcp_config.available_tools = tool_list + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "healthy" + mcp_config.error_message = None + + # 更新工具状态 + config.status = ToolStatus.AVAILABLE.value + + self.db.commit() + + return { + "success": True, + "message": "工具列表同步成功", + "tools_count": len(tool_names), + "tools": tool_names + } + + except Exception as e: + # 更新错误状态 + try: + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) + if mcp_config: mcp_config.last_health_check = datetime.now() mcp_config.health_status = "error" mcp_config.error_message = str(e) config.status = ToolStatus.ERROR.value self.db.commit() - - return {"success": False, "message": f"获取工具列表失败: {str(e)}"} - else: - # 连接失败 - mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "error" - mcp_config.error_message = "连接失败" - config.status = ToolStatus.ERROR.value - self.db.commit() - - return {"success": False, "message": "MCP连接失败"} - - except Exception as e: + except: + pass + logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}") return {"success": False, "message": f"同步失败: {str(e)}"} diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index bf0375fb..b77a4ada 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -7,7 +7,6 @@ User Memory Service import os import uuid from collections import Counter -from dataclasses import dataclass from datetime import 
datetime from typing import Any, Dict, List, Optional, Tuple @@ -22,7 +21,298 @@ from sqlalchemy.orm import Session logger = get_logger(__name__) -# Neo4j connector instan +# Neo4j connector instance for analytics functions +_neo4j_connector = Neo4jConnector() + +# Default LLM ID for fallback +DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") + + +# ============================================================================ +# Internal Helper Classes +# ============================================================================ + +class TagClassification(BaseModel): + """Represents the classification of a tag into a specific domain.""" + domain: str = Field( + ..., + description="The domain the tag belongs to, chosen from the predefined list.", + examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"], + ) + + +def _get_llm_client_for_user(user_id: str): + """ + Get LLM client for a specific user based on their config. + + Args: + user_id: User ID to get config for + + Returns: + LLM client instance + """ + with get_db_context() as db: + try: + from app.services.memory_agent_service import get_end_user_connected_config + connected_config = get_end_user_connected_config(user_id, db) + config_id = connected_config.get("memory_config_id") + + if config_id: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + factory = MemoryClientFactory(db) + return factory.get_llm_client(memory_config.llm_model_id) + else: + factory = MemoryClientFactory(db) + return factory.get_llm_client(DEFAULT_LLM_ID) + except Exception as e: + logger.warning(f"Failed to get user connected config, using default LLM: {e}") + factory = MemoryClientFactory(db) + return factory.get_llm_client(DEFAULT_LLM_ID) + + +class MemoryInsightHelper: + """ + Internal helper class for memory insight analysis. + Provides basic data retrieval and analysis functionality. 
+ """ + + def __init__(self, user_id: str): + self.user_id = user_id + self.neo4j_connector = Neo4jConnector() + self.llm_client = _get_llm_client_for_user(user_id) + + async def close(self): + """Close database connection.""" + await self.neo4j_connector.close() + + async def get_domain_distribution(self) -> dict[str, float]: + """Calculate the distribution of memory domains based on hot tags.""" + from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags + + hot_tags = await get_hot_memory_tags(self.user_id) + if not hot_tags: + return {} + + domain_counts = Counter() + for tag, _ in hot_tags: + prompt = f"""请将以下标签归类到最合适的领域中。 + +可选领域及其关键词: +- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等 +- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等 +- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等 +- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等 +- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等 +- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等 +- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等 +- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等 +- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等 +- 其他:确实无法归入以上任何类别的内容 + +标签: {tag} + +分析步骤: +1. 仔细理解标签的核心含义和使用场景 +2. 对比各个领域的关键词,找到最匹配的领域 +3. 特别注意: + - 历史、科学、文化等知识性内容应归类为"学习" + - 学校、课程、考试等正式教育场景应归类为"教育" + - 只有在标签完全不属于上述9个具体领域时,才选择"其他" +4. 
如果标签与某个领域有任何相关性,就选择该领域,不要选"其他" + +请直接返回最合适的领域名称。""" + messages = [ + {"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"}, + {"role": "user", "content": prompt} + ] + classification = await self.llm_client.response_structured( + messages=messages, + response_model=TagClassification, + ) + if classification and hasattr(classification, 'domain') and classification.domain: + domain_counts[classification.domain] += 1 + + total_tags = sum(domain_counts.values()) + if total_tags == 0: + return {} + + domain_distribution = { + domain: count / total_tags for domain, count in domain_counts.items() + } + return dict(sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)) + + async def get_active_periods(self) -> list[int]: + """ + Identify the top 2 most active months for the user. + Only returns months if there is valid and diverse time data. + """ + query = """ + MATCH (d:Dialogue) + WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> '' + RETURN d.created_at AS creation_time + """ + records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + + if not records: + return [] + + month_counts = Counter() + valid_dates_count = 0 + for record in records: + creation_time = record.get("creation_time") + if not creation_time: + continue + try: + # 处理 Neo4j DateTime 对象或字符串 + if hasattr(creation_time, 'to_native'): + dt_object = creation_time.to_native() + elif isinstance(creation_time, str): + dt_object = datetime.fromisoformat(creation_time.replace("Z", "+00:00")) + elif isinstance(creation_time, datetime): + dt_object = creation_time + else: + dt_object = datetime.fromisoformat(str(creation_time).replace("Z", "+00:00")) + + month_counts[dt_object.month] += 1 + valid_dates_count += 1 + except (ValueError, TypeError, AttributeError): + continue + + if not month_counts or valid_dates_count == 0: + return [] + + 
# Check if time distribution is too concentrated (likely batch imported data) + unique_months = len(month_counts) + if unique_months <= 2: + most_common_count = month_counts.most_common(1)[0][1] + if most_common_count / valid_dates_count > 0.8: + return [] + + if unique_months >= 3: + most_common_months = month_counts.most_common(2) + return [month for month, _ in most_common_months] + + if unique_months == 2: + counts = list(month_counts.values()) + ratio = min(counts) / max(counts) + if ratio > 0.3: + most_common_months = month_counts.most_common(2) + return [month for month, _ in most_common_months] + + return [] + + async def get_social_connections(self) -> dict | None: + """Find the user with whom the most memories are shared.""" + query = """ + MATCH (c1:Chunk {group_id: $group_id}) + OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement) + OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk) + WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL + WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements + WHERE common_statements > 0 + RETURN other_user_id, common_statements + ORDER BY common_statements DESC + LIMIT 1 + """ + records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + if not records or not records[0].get("other_user_id"): + return None + + most_connected_user = records[0]["other_user_id"] + common_memories_count = records[0]["common_statements"] + + time_range_query = """ + MATCH (c:Chunk) + WHERE c.group_id IN [$user_id, $other_user_id] + RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time + """ + time_records = await self.neo4j_connector.execute_query( + time_range_query, + user_id=self.user_id, + other_user_id=most_connected_user + ) + start_year, end_year = "N/A", "N/A" + if time_records and time_records[0]["start_time"]: + start_time = time_records[0]["start_time"] + end_time = time_records[0]["end_time"] + + # 处理 Neo4j DateTime 对象或字符串 + try: + if hasattr(start_time, 'to_native'): 
+ start_year = start_time.to_native().year + elif isinstance(start_time, str): + start_year = datetime.fromisoformat(start_time.replace("Z", "+00:00")).year + elif isinstance(start_time, datetime): + start_year = start_time.year + else: + start_year = datetime.fromisoformat(str(start_time).replace("Z", "+00:00")).year + except Exception: + start_year = "N/A" + + try: + if hasattr(end_time, 'to_native'): + end_year = end_time.to_native().year + elif isinstance(end_time, str): + end_year = datetime.fromisoformat(end_time.replace("Z", "+00:00")).year + elif isinstance(end_time, datetime): + end_year = end_time.year + else: + end_year = datetime.fromisoformat(str(end_time).replace("Z", "+00:00")).year + except Exception: + end_year = "N/A" + + return { + "user_id": most_connected_user, + "common_memories_count": common_memories_count, + "time_range": f"{start_year}-{end_year}", + } + + +class UserSummaryHelper: + """ + Internal helper class for user summary generation. + Provides data retrieval functionality for user summary analysis. 
+ """ + + def __init__(self, user_id: str): + self.user_id = user_id + self.connector = Neo4jConnector() + self.llm = _get_llm_client_for_user(user_id) + + async def close(self): + """Close database connection.""" + await self.connector.close() + + async def get_recent_statements(self, limit: int = 80) -> List[Dict[str, Any]]: + """Fetch recent statements authored by the user/group for context.""" + query = ( + "MATCH (s:Statement) " + "WHERE s.group_id = $group_id AND s.statement IS NOT NULL " + "RETURN s.statement AS statement, s.created_at AS created_at " + "ORDER BY created_at DESC LIMIT $limit" + ) + rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit) + records = [] + for r in rows: + try: + records.append({ + "statement": r.get("statement", ""), + "created_at": r.get("created_at") + }) + except Exception: + continue + return records + + async def get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]: + """Get meaningful entities and their frequencies using hot tag logic.""" + from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags + return await get_hot_memory_tags(self.user_id, limit=limit) + + +# ============================================================================ +# Service Class +# ============================================================================ class UserMemoryService: @@ -30,6 +320,7 @@ class UserMemoryService: def __init__(self): logger.info("UserMemoryService initialized") + self.neo4j_connector = Neo4jConnector() @staticmethod def _datetime_to_timestamp(dt: Optional[Any]) -> Optional[int]: @@ -592,6 +883,866 @@ class UserMemoryService: "failed": failed, "errors": errors + [{"error": f"批量处理失败: {str(e)}"}] } + + async def _get_title_and_type( + self, + summary_id: str, + end_user_id: str + ) -> Tuple[str, str]: + """ + 读取情景记忆的标题(title)和类型(type) + + 仅负责读取已存在的title和type,不进行生成 + title从name属性读取,type从memory_type属性读取 + + Args: + summary_id: Summary节点的ID + end_user_id: 终端用户ID 
(group_id) + + Returns: + (标题, 类型)元组,如果不存在则返回默认值 + """ + try: + # 查询Summary节点的name(作为title)和memory_type(作为type) + query = """ + MATCH (s:MemorySummary) + WHERE elementId(s) = $summary_id AND s.group_id = $group_id + RETURN s.name AS title, s.memory_type AS type + """ + + result = await self.neo4j_connector.execute_query( + query, + summary_id=summary_id, + group_id=end_user_id + ) + + if not result or len(result) == 0: + logger.warning(f"未找到 summary_id={summary_id} 的节点") + return ("未知标题", "其他") + + record = result[0] + title = record.get("title") or "未命名" + episodic_type = record.get("type") or "其他" + + return (title, episodic_type) + + except Exception as e: + logger.error(f"读取标题和类型时出错: {str(e)}", exc_info=True) + return ("错误", "其他") + + @staticmethod + async def generate_title_and_type_for_summary( + content: str, + end_user_id: str + ) -> Tuple[str, str]: + """ + 为MemorySummary生成标题和类型(静态方法,用于创建节点时调用) + + 此方法应该在创建MemorySummary节点时调用,生成title和type + + Args: + content: Summary的内容文本 + end_user_id: 终端用户ID (group_id) + + Returns: + (标题, 类型)元组 + """ + from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt + import json + + # 定义有效的类型集合 + VALID_TYPES = { + "conversation", # 对话 + "project_work", # 项目/工作 + "learning", # 学习 + "decision", # 决策 + "important_event" # 重要事件 + } + DEFAULT_TYPE = "conversation" # 默认类型 + + try: + if not content: + logger.warning("content为空,无法生成标题和类型") + return ("空内容", DEFAULT_TYPE) + + # 1. 渲染Jinja2提示词模板 + prompt = await render_episodic_title_and_type_prompt(content) + + # 2. 调用LLM生成标题和类型 + llm_client = _get_llm_client_for_user(end_user_id) + messages = [ + {"role": "user", "content": prompt} + ] + + response = await llm_client.chat(messages=messages) + + # 3. 
解析LLM响应 + content_response = response.content + if isinstance(content_response, list): + if len(content_response) > 0: + if isinstance(content_response[0], dict): + text = content_response[0].get('text', content_response[0].get('content', str(content_response[0]))) + full_response = str(text) + else: + full_response = str(content_response[0]) + else: + full_response = "" + elif isinstance(content_response, dict): + full_response = str(content_response.get('text', content_response.get('content', str(content_response)))) + else: + full_response = str(content_response) if content_response is not None else "" + + # 4. 解析JSON响应 + try: + # 尝试从响应中提取JSON + # 移除可能的markdown代码块标记 + json_str = full_response.strip() + if json_str.startswith("```json"): + json_str = json_str[7:] + if json_str.startswith("```"): + json_str = json_str[3:] + if json_str.endswith("```"): + json_str = json_str[:-3] + json_str = json_str.strip() + + result_data = json.loads(json_str) + title = result_data.get("title", "未知标题") + episodic_type_raw = result_data.get("type", DEFAULT_TYPE) + + # 5. 
校验和归一化类型 + # 将类型转换为小写并去除空格 + episodic_type_normalized = str(episodic_type_raw).lower().strip() + + # 检查是否在有效类型集合中 + if episodic_type_normalized in VALID_TYPES: + episodic_type = episodic_type_normalized + else: + # 尝试映射常见的中文类型到英文 + type_mapping = { + "对话": "conversation", + "项目": "project_work", + "工作": "project_work", + "项目/工作": "project_work", + "学习": "learning", + "决策": "decision", + "重要事件": "important_event", + "事件": "important_event" + } + episodic_type = type_mapping.get(episodic_type_raw, DEFAULT_TYPE) + logger.warning( + f"LLM返回的类型 '{episodic_type_raw}' 不在有效集合中," + f"已归一化为 '{episodic_type}'" + ) + + logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}") + return (title, episodic_type) + + except json.JSONDecodeError: + logger.error(f"无法解析LLM响应为JSON: {full_response}") + return ("解析失败", DEFAULT_TYPE) + + except Exception as e: + logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True) + return ("错误", DEFAULT_TYPE) + + async def _extract_involved_objects( + self, + summary_id: str, + end_user_id: str + ) -> List[str]: + """ + 提取情景记忆涉及的前3个最重要实体 + + Args: + summary_id: Summary节点的ID + end_user_id: 终端用户ID (group_id) + + Returns: + 前3个实体的name属性列表 + """ + try: + # 查询Summary节点指向的Statement节点,再查询Statement指向的ExtractedEntity节点 + # 按activation_value降序排序,返回前3个 + query = """ + MATCH (s:MemorySummary) + WHERE elementId(s) = $summary_id AND s.group_id = $group_id + MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) + MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity) + WHERE entity.activation_value IS NOT NULL + RETURN DISTINCT entity.name AS name, entity.activation_value AS activation + ORDER BY activation DESC + LIMIT 3 + """ + + result = await self.neo4j_connector.execute_query( + query, + summary_id=summary_id, + group_id=end_user_id + ) + + # 提取实体名称 + involved_objects = [record["name"] for record in result if record.get("name")] + + logger.info(f"成功提取 summary_id={summary_id} 的涉及对象: {involved_objects}") + + return involved_objects + + except Exception as 
e: + logger.error(f"提取涉及对象时出错: {str(e)}", exc_info=True) + return [] + + async def _extract_content_records( + self, + summary_id: str, + end_user_id: str + ) -> List[str]: + """ + 提取情景记忆的内容记录 + + Args: + summary_id: Summary节点的ID + end_user_id: 终端用户ID (group_id) + + Returns: + 所有Statement节点的statement属性内容列表 + """ + try: + # 查询Summary节点指向的所有Statement节点 + query = """ + MATCH (s:MemorySummary) + WHERE elementId(s) = $summary_id AND s.group_id = $group_id + MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) + WHERE stmt.statement IS NOT NULL AND stmt.statement <> '' + RETURN stmt.statement AS statement + """ + + result = await self.neo4j_connector.execute_query( + query, + summary_id=summary_id, + group_id=end_user_id + ) + + # 提取statement内容 + content_records = [record["statement"] for record in result if record.get("statement")] + + logger.info(f"成功提取 summary_id={summary_id} 的内容记录,共 {len(content_records)} 条") + + return content_records + + except Exception as e: + logger.error(f"提取内容记录时出错: {str(e)}", exc_info=True) + return [] + + async def _extract_episodic_emotion( + self, + summary_id: str, + end_user_id: str + ) -> Optional[str]: + """ + 提取情景记忆的主要情绪 + + Args: + summary_id: Summary节点的ID + end_user_id: 终端用户ID (group_id) + + Returns: + 最大emotion_intensity对应的emotion_type,如果没有则返回None + """ + try: + # 查询Summary节点指向的所有Statement节点 + # 筛选具有emotion_type属性的节点 + # 按emotion_intensity降序排序,返回第一个 + query = """ + MATCH (s:MemorySummary) + WHERE elementId(s) = $summary_id AND s.group_id = $group_id + MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) + WHERE stmt.emotion_type IS NOT NULL + AND stmt.emotion_intensity IS NOT NULL + RETURN stmt.emotion_type AS emotion_type, + stmt.emotion_intensity AS emotion_intensity + ORDER BY emotion_intensity DESC + LIMIT 1 + """ + + result = await self.neo4j_connector.execute_query( + query, + summary_id=summary_id, + group_id=end_user_id + ) + + # 提取emotion_type + if result and len(result) > 0: + emotion_type = 
result[0].get("emotion_type") + logger.info(f"成功提取 summary_id={summary_id} 的情绪: {emotion_type}") + return emotion_type + else: + logger.info(f"summary_id={summary_id} 没有情绪信息") + return None + + except Exception as e: + logger.error(f"提取情景记忆情绪时出错: {str(e)}", exc_info=True) + return None + + async def get_episodic_memory_overview( + self, + db: Session, + end_user_id: str, + time_range: str = "all", + episodic_type: str = "all", + title_keyword: Optional[str] = None + ) -> Dict[str, Any]: + """ + 获取情景记忆总览信息 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID + time_range: 时间范围筛选 + episodic_type: 情景类型筛选 + title_keyword: 标题关键词(可选,用于模糊搜索) + """ + try: + logger.info( + f"开始查询 end_user_id={end_user_id} 的情景记忆总览, " + f"time_range={time_range}, episodic_type={episodic_type}, title_keyword={title_keyword}" + ) + + # 1. 先查询所有情景记忆的总数(不受筛选条件限制) + total_all_query = """ + MATCH (s:MemorySummary) + WHERE s.group_id = $group_id + RETURN count(s) AS total_all + """ + total_all_result = await self.neo4j_connector.execute_query( + total_all_query, + group_id=end_user_id + ) + total_all = total_all_result[0]["total_all"] if total_all_result else 0 + + # 2. 计算时间范围的起始时间戳 + time_filter = self._calculate_time_filter(time_range) + + # 3. 构建Cypher查询 + query = """ + MATCH (s:MemorySummary) + WHERE s.group_id = $group_id + """ + + # 添加时间范围过滤 + if time_filter: + query += " AND s.created_at >= $time_filter" + + # 添加标题关键词过滤(如果提供了title_keyword) + if title_keyword: + query += " AND toLower(s.name) CONTAINS toLower($title_keyword)" + + query += """ + RETURN elementId(s) AS id, + s.created_at AS created_at, + s.memory_type AS type, + s.name AS title + ORDER BY s.created_at DESC + """ + + params = {"group_id": end_user_id} + if time_filter: + params["time_filter"] = time_filter + if title_keyword: + params["title_keyword"] = title_keyword + + result = await self.neo4j_connector.execute_query(query, **params) + + # 4. 
如果没有数据,返回空列表 + if not result: + logger.info(f"end_user_id={end_user_id} 没有情景记忆数据") + return { + "total": 0, + "total_all": total_all, + "episodic_memories": [] + } + + # 5. 对每个节点读取标题和类型,并应用类型筛选 + episodic_memories = [] + for record in result: + summary_id = record["id"] + created_at_str = record.get("created_at") + memory_type = record.get("type", "其他") + title = record.get("title") or "未命名" # 直接从查询结果获取标题 + + # 应用情景类型筛选 + if episodic_type != "all": + # 检查类型是否匹配 + # 注意:Neo4j 中存储的 memory_type 现在应该是英文格式(如 "conversation", "project_work") + # 但为了兼容旧数据,我们也支持中文格式的匹配 + type_mapping = { + "conversation": "对话", + "project_work": "项目/工作", + "learning": "学习", + "decision": "决策", + "important_event": "重要事件" + } + + # 获取对应的中文类型(用于兼容旧数据) + chinese_type = type_mapping.get(episodic_type) + + # 检查类型是否匹配(支持新的英文格式和旧的中文格式) + if memory_type != episodic_type and memory_type != chinese_type: + continue + + # 转换时间戳 + created_at_timestamp = None + if created_at_str: + try: + from datetime import datetime + dt_object = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + created_at_timestamp = int(dt_object.timestamp() * 1000) + except (ValueError, TypeError, AttributeError) as e: + logger.warning(f"无法解析时间戳: {created_at_str}, error={str(e)}") + + episodic_memories.append({ + "id": summary_id, + "title": title, + "type": memory_type, + "created_at": created_at_timestamp + }) + + logger.info( + f"成功获取 end_user_id={end_user_id} 的情景记忆总览," + f"筛选后 {len(episodic_memories)} 条,总共 {total_all} 条" + ) + + return { + "total": len(episodic_memories), + "total_all": total_all, + "episodic_memories": episodic_memories + } + + except Exception as e: + logger.error(f"获取情景记忆总览时出错: {str(e)}", exc_info=True) + raise + + def _calculate_time_filter(self, time_range: str) -> Optional[str]: + """ + 根据时间范围计算过滤的起始时间 + + Args: + time_range: 时间范围 (all/today/this_week/this_month) + + Returns: + ISO格式的时间字符串,如果是"all"则返回None + """ + from datetime import datetime, timedelta + import pytz + + if time_range == 
"all": + return None + + # 获取当前时间(UTC) + now = datetime.now(pytz.UTC) + + if time_range == "today": + # 今天的开始时间(00:00:00) + start_time = now.replace(hour=0, minute=0, second=0, microsecond=0) + elif time_range == "this_week": + # 本周的开始时间(周一00:00:00) + days_since_monday = now.weekday() + start_time = (now - timedelta(days=days_since_monday)).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + elif time_range == "this_month": + # 本月的开始时间(1号00:00:00) + start_time = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + else: + return None + + # 返回ISO格式字符串 + return start_time.isoformat() + + async def get_episodic_memory_details( + self, + db: Session, + end_user_id: str, + summary_id: str + ) -> Dict[str, Any]: + """ + 获取单个情景记忆的详细信息 + + """ + try: + logger.info(f"开始查询 end_user_id={end_user_id}, summary_id={summary_id} 的情景记忆详情") + + # 1. 查询指定的MemorySummary节点 + query = """ + MATCH (s:MemorySummary) + WHERE elementId(s) = $summary_id AND s.group_id = $group_id + RETURN elementId(s) AS id, s.created_at AS created_at + """ + + result = await self.neo4j_connector.execute_query( + query, + summary_id=summary_id, + group_id=end_user_id + ) + + # 2. 如果节点不存在,返回错误 + if not result or len(result) == 0: + logger.warning(f"未找到 summary_id={summary_id} 的情景记忆") + raise ValueError(f"情景记忆不存在: summary_id={summary_id}") + + # 3. 获取基本信息 + record = result[0] + created_at_str = record.get("created_at") + + # 转换时间戳 + created_at_timestamp = None + if created_at_str: + try: + from datetime import datetime + dt_object = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + created_at_timestamp = int(dt_object.timestamp() * 1000) + except (ValueError, TypeError, AttributeError) as e: + logger.warning(f"无法解析时间戳: {created_at_str}, error={str(e)}") + + # 4. 调用_get_title_and_type读取标题和类型 + title, episodic_type = await self._get_title_and_type( + summary_id=summary_id, + end_user_id=end_user_id + ) + + # 5. 
调用_extract_involved_objects提取涉及对象 + involved_objects = await self._extract_involved_objects( + summary_id=summary_id, + end_user_id=end_user_id + ) + + # 6. 调用_extract_content_records提取内容记录 + content_records = await self._extract_content_records( + summary_id=summary_id, + end_user_id=end_user_id + ) + + # 7. 调用_extract_episodic_emotion提取情绪 + emotion = await self._extract_episodic_emotion( + summary_id=summary_id, + end_user_id=end_user_id + ) + + # 8. 返回完整的详情信息 + details = { + "id": summary_id, + "created_at": created_at_timestamp, + "involved_objects": involved_objects, + "episodic_type": episodic_type, + "content_records": content_records, + "emotion": emotion + } + + logger.info(f"成功获取 summary_id={summary_id} 的情景记忆详情") + + return details + + except ValueError: + # 重新抛出ValueError,让Controller层处理 + raise + except Exception as e: + logger.error(f"获取情景记忆详情时出错: {str(e)}", exc_info=True) + raise + + async def get_explicit_memory_overview( + self, + db: Session, + end_user_id: str + ) -> Dict[str, Any]: + """ + 获取显性记忆总览信息 + + 返回两部分: + 1. 情景记忆(episodic_memories)- 来自MemorySummary节点 + 2. 语义记忆(semantic_memories)- 来自ExtractedEntity节点(is_explicit_memory=true) + + Args: + db: 数据库会话 + end_user_id: 终端用户ID + + Returns: + { + "total": int, + "episodic_memories": [ + { + "id": str, + "title": str, + "content": str, + "created_at": int, + "emotion": Dict + } + ], + "semantic_memories": [ + { + "id": str, + "name": str, + "entity_type": str, + "core_definition": str, + "detailed_notes": str, + "created_at": int + } + ] + } + """ + try: + logger.info(f"开始查询 end_user_id={end_user_id} 的显性记忆总览(情景记忆+语义记忆)") + + # ========== 1. 
查询情景记忆(MemorySummary节点) ========== + episodic_query = """ + MATCH (s:MemorySummary) + WHERE s.group_id = $group_id + RETURN elementId(s) AS id, + s.name AS title, + s.content AS content, + s.created_at AS created_at + ORDER BY s.created_at DESC + """ + + episodic_result = await self.neo4j_connector.execute_query( + episodic_query, + group_id=end_user_id + ) + + # 处理情景记忆数据 + episodic_memories = [] + if episodic_result: + for record in episodic_result: + summary_id = record["id"] + title = record.get("title") or "未命名" + content = record.get("content") or "" + created_at_str = record.get("created_at") + + # 转换时间戳 + created_at_timestamp = None + if created_at_str: + try: + from datetime import datetime + dt_object = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + created_at_timestamp = int(dt_object.timestamp() * 1000) + except (ValueError, TypeError, AttributeError) as e: + logger.warning(f"无法解析时间戳: {created_at_str}, error={str(e)}") + + # 注意:总览接口不返回 emotion 字段 + episodic_memories.append({ + "id": summary_id, + "title": title, + "content": content, + "created_at": created_at_timestamp + }) + + # ========== 2. 
查询语义记忆(ExtractedEntity节点) ========== + semantic_query = """ + MATCH (e:ExtractedEntity) + WHERE e.group_id = $group_id + AND e.is_explicit_memory = true + RETURN elementId(e) AS id, + e.name AS name, + e.entity_type AS entity_type, + e.description AS core_definition, + e.example AS detailed_notes, + e.created_at AS created_at + ORDER BY e.created_at DESC + """ + + semantic_result = await self.neo4j_connector.execute_query( + semantic_query, + group_id=end_user_id + ) + + # 处理语义记忆数据 + semantic_memories = [] + if semantic_result: + for record in semantic_result: + entity_id = record["id"] + name = record.get("name") or "未命名" + entity_type = record.get("entity_type") or "未分类" + core_definition = record.get("core_definition") or "" + created_at_str = record.get("created_at") + + # 转换时间戳 + created_at_timestamp = None + if created_at_str: + try: + from datetime import datetime + dt_object = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + created_at_timestamp = int(dt_object.timestamp() * 1000) + except (ValueError, TypeError, AttributeError) as e: + logger.warning(f"无法解析时间戳: {created_at_str}, error={str(e)}") + + # 注意:总览接口不返回 detailed_notes 字段 + semantic_memories.append({ + "id": entity_id, + "name": name, + "entity_type": entity_type, + "core_definition": core_definition, + "created_at": created_at_timestamp + }) + + # ========== 3. 
返回结果 ========== + total_count = len(episodic_memories) + len(semantic_memories) + + logger.info( + f"成功获取 end_user_id={end_user_id} 的显性记忆总览," + f"情景记忆={len(episodic_memories)} 条,语义记忆={len(semantic_memories)} 条," + f"总计 {total_count} 条" + ) + + return { + "total": total_count, + "episodic_memories": episodic_memories, + "semantic_memories": semantic_memories + } + + except Exception as e: + logger.error(f"获取显性记忆总览时出错: {str(e)}", exc_info=True) + raise + + async def get_explicit_memory_details( + self, + db: Session, + end_user_id: str, + memory_id: str + ) -> Dict[str, Any]: + """ + 获取显性记忆详情 + + 根据 memory_id 查询情景记忆或语义记忆的详细信息。 + 先尝试查询情景记忆,如果找不到再查询语义记忆。 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID + memory_id: 记忆ID(可以是情景记忆或语义记忆的ID) + + Returns: + 情景记忆返回: + { + "memory_type": "episodic", + "title": str, + "content": str, + "emotion": Dict, + "created_at": int + } + + 语义记忆返回: + { + "memory_type": "semantic", + "name": str, + "core_definition": str, + "detailed_notes": str, + "created_at": int + } + + Raises: + ValueError: 当记忆不存在时 + """ + try: + logger.info(f"开始查询显性记忆详情: end_user_id={end_user_id}, memory_id={memory_id}") + + # ========== 1. 
先尝试查询情景记忆 ========== + episodic_query = """ + MATCH (s:MemorySummary) + WHERE elementId(s) = $memory_id AND s.group_id = $group_id + RETURN s.name AS title, + s.content AS content, + s.created_at AS created_at + """ + + episodic_result = await self.neo4j_connector.execute_query( + episodic_query, + memory_id=memory_id, + group_id=end_user_id + ) + + if episodic_result and len(episodic_result) > 0: + record = episodic_result[0] + title = record.get("title") or "未命名" + content = record.get("content") or "" + created_at_str = record.get("created_at") + + # 转换时间戳 + created_at_timestamp = None + if created_at_str: + try: + from datetime import datetime + dt_object = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + created_at_timestamp = int(dt_object.timestamp() * 1000) + except (ValueError, TypeError, AttributeError) as e: + logger.warning(f"无法解析时间戳: {created_at_str}, error={str(e)}") + + # 获取情绪信息 + emotion = await self._extract_episodic_emotion( + summary_id=memory_id, + end_user_id=end_user_id + ) + + logger.info(f"成功获取情景记忆详情: memory_id={memory_id}") + return { + "memory_type": "episodic", + "title": title, + "content": content, + "emotion": emotion, + "created_at": created_at_timestamp + } + + # ========== 2. 
如果不是情景记忆,尝试查询语义记忆 ========== + semantic_query = """ + MATCH (e:ExtractedEntity) + WHERE elementId(e) = $memory_id + AND e.group_id = $group_id + AND e.is_explicit_memory = true + RETURN e.name AS name, + e.description AS core_definition, + e.example AS detailed_notes, + e.created_at AS created_at + """ + + semantic_result = await self.neo4j_connector.execute_query( + semantic_query, + memory_id=memory_id, + group_id=end_user_id + ) + + if semantic_result and len(semantic_result) > 0: + record = semantic_result[0] + name = record.get("name") or "未命名" + core_definition = record.get("core_definition") or "" + detailed_notes = record.get("detailed_notes") or "" + created_at_str = record.get("created_at") + + # 转换时间戳 + created_at_timestamp = None + if created_at_str: + try: + from datetime import datetime + dt_object = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + created_at_timestamp = int(dt_object.timestamp() * 1000) + except (ValueError, TypeError, AttributeError) as e: + logger.warning(f"无法解析时间戳: {created_at_str}, error={str(e)}") + + logger.info(f"成功获取语义记忆详情: memory_id={memory_id}") + return { + "memory_type": "semantic", + "name": name, + "core_definition": core_definition, + "detailed_notes": detailed_notes, + "created_at": created_at_timestamp + } + + # ========== 3. 两种记忆都找不到 ========== + logger.warning(f"记忆不存在: memory_id={memory_id}, end_user_id={end_user_id}") + raise ValueError(f"记忆不存在: memory_id={memory_id}") + + except ValueError: + # 重新抛出 ValueError(记忆不存在) + raise + except Exception as e: + logger.error(f"获取显性记忆详情时出错: {str(e)}", exc_info=True) + raise # 独立的分析函数 @@ -601,7 +1752,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> 生成记忆洞察报告(四个维度) 这个函数包含完整的业务逻辑: - 1. 使用 MemoryInsight 工具类获取基础数据(领域分布、活跃时段、社交关联) + 1. 使用 MemoryInsightHelper 工具类获取基础数据(领域分布、活跃时段、社交关联) 2. 使用 Jinja2 模板渲染提示词 3. 调用 LLM 生成四个维度的自然语言报告 4. 
解析并返回四个部分 @@ -620,7 +1771,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt import re - insight = MemoryInsight(end_user_id) + insight = MemoryInsightHelper(end_user_id) try: # 1. 并行获取三个维度的数据 @@ -722,7 +1873,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, 生成用户摘要(包含四个部分) 这个函数包含完整的业务逻辑: - 1. 使用 UserSummary 工具类获取基础数据(实体、语句) + 1. 使用 UserSummaryHelper 工具类获取基础数据(实体、语句) 2. 使用 prompt_utils 渲染提示词 3. 调用 LLM 生成四部分内容:基本介绍、性格特点、核心价值观、一句话总结 @@ -737,20 +1888,19 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, "one_sentence": str } """ - from app.core.memory.analytics.user_summary import UserSummary from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt import re - # 创建 UserSummary 实例 - user_summary_tool = UserSummary(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123")) + # 创建 UserSummaryHelper 实例 + user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123")) try: # 1) 收集上下文数据 - entities = await user_summary_tool._get_top_entities(limit=40) - statements = await user_summary_tool._get_recent_statements(limit=100) + entities = await user_summary_tool.get_top_entities(limit=40) + statements = await user_summary_tool.get_recent_statements(limit=100) entity_lines = [f"{name} ({freq})" for name, freq in entities][:20] - statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20] + statement_samples = [s["statement"].strip() for s in statements if s.get("statement", "").strip()][:20] # 2) 使用 prompt_utils 渲染提示词 user_prompt = await render_user_summary_prompt( @@ -794,6 +1944,28 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, core_values = core_values_match.group(1).strip() if core_values_match else "" one_sentence = one_sentence_match.group(1).strip() if 
one_sentence_match else "" + # 6) 清理可能包含的反思内容(防御性编程) + # 如果 LLM 仍然输出了反思内容,在这里过滤掉 + def clean_reflection_content(text: str) -> str: + """移除可能包含的反思内容""" + if not text: + return text + # 移除 "---" 之后的所有内容(通常是反思部分的开始) + if '---' in text: + text = text.split('---')[0].strip() + # 移除 "**Step" 开头的内容 + if '**Step' in text: + text = text.split('**Step')[0].strip() + # 移除 "Self-Review" 相关内容 + if 'Self-Review' in text or 'self-review' in text: + text = re.sub(r'[\-\*]*\s*Self-Review.*$', '', text, flags=re.IGNORECASE | re.DOTALL).strip() + return text + + user_summary = clean_reflection_content(user_summary) + personality = clean_reflection_content(personality) + core_values = clean_reflection_content(core_values) + one_sentence = clean_reflection_content(one_sentence) + return { "user_summary": user_summary, "personality": personality, @@ -1078,7 +2250,7 @@ async def analytics_graph_data( "group_id": end_user_id, "limit": limit } - + # 执行节点查询 node_results = await _neo4j_connector.execute_query(node_query, **node_params) @@ -1093,7 +2265,7 @@ async def analytics_graph_data( node_props = record["properties"] # 根据节点类型提取需要的属性字段 - filtered_props = _extract_node_properties(node_label, node_props) + filtered_props = await _extract_node_properties(node_label, node_props,node_id) # 直接使用数据库中的 caption,如果没有则使用节点类型作为默认值 caption = filtered_props.get("caption", node_label) @@ -1199,7 +2371,7 @@ async def analytics_graph_data( # 辅助函数 -def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str, Any]: +async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]: """ 根据节点类型提取需要的属性字段 @@ -1214,8 +2386,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str field_whitelist = { "Dialogue": ["content", "created_at"], "Chunk": ["content", "created_at"], - "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"], - "ExtractedEntity": ["description", "name", 
"entity_type", "created_at", "caption"], + "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"], + "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"], "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 } @@ -1226,7 +2398,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str if not allowed_fields: # 对于未定义的节点类型,只返回基本字段 allowed_fields = ["name", "created_at", "caption"] - + count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;""" + node_results = await (_neo4j_connector.execute_query(count_neo4j)) # 提取白名单中的字段 filtered_props = {} for field in allowed_fields: @@ -1234,7 +2407,8 @@ def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str value = properties[field] # 清理 Neo4j 特殊类型 filtered_props[field] = _clean_neo4j_value(value) - + filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0] + print(filtered_props) return filtered_props diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index d96efdf7..68d6279b 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -17,6 +17,7 @@ from app.core.workflow.validator import validate_workflow_config from app.db import get_db, get_db_context from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.repositories.end_user_repository import EndUserRepository +from app.services.multi_agent_service import convert_uuids_to_str from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, @@ -364,7 +365,7 @@ class WorkflowService: execution.status = status if output_data is not None: - execution.output_data = output_data + execution.output_data = convert_uuids_to_str(output_data) if error_message 
is not None: execution.error_message = error_message if error_node_id is not None: diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/templates/workflows/simple_qa/template.yml index 0843744d..2cf0f9b1 100644 --- a/api/app/templates/workflows/simple_qa/template.yml +++ b/api/app/templates/workflows/simple_qa/template.yml @@ -68,12 +68,7 @@ edges: label: 完成 # 变量定义 -variables: - - name: user_question - type: string - required: true - description: 用户的问题 - default: "" +variables: [] # 执行配置 execution_config: diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index 4fe692c1..834d22af 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -5,13 +5,66 @@ Utility functions for converting between dict and model objects for different ap """ import uuid -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Union from datetime import datetime -from app.models import AppRelease +from app.models import AppRelease, WorkflowConfig from app.models.agent_app_config_model import AgentConfig from app.models.multi_agent_model import MultiAgentConfig + +def model_parameters_to_dict(model_parameters: Any) -> Optional[Dict[str, Any]]: + """将 ModelParameters 对象转换为字典 + + Args: + model_parameters: ModelParameters 对象、字典或 None + + Returns: + 字典格式的模型参数,如果输入为 None 则返回 None + """ + if model_parameters is None: + return None + + if isinstance(model_parameters, dict): + return model_parameters + + # Pydantic v2 + if hasattr(model_parameters, 'model_dump'): + return model_parameters.model_dump() + + # Pydantic v1 + if hasattr(model_parameters, 'dict'): + return model_parameters.dict() + + # 其他情况尝试转换 + try: + return dict(model_parameters) + except (TypeError, ValueError): + return None + + +def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]: + """将字典转换为 ModelParameters 对象 + + Args: + data: 字典格式的模型参数或 None + + Returns: + ModelParameters 对象,如果输入为 None 则返回 None + 
""" + if data is None: + return None + + from app.schemas import ModelParameters + + if isinstance(data, ModelParameters): + return data + + if isinstance(data, dict): + return ModelParameters(**data) + + return None + class AgentConfigProxy: """Proxy class for AgentConfig (legacy compatibility)""" @@ -28,7 +81,7 @@ class AgentConfigProxy: def agent_config_4_app_release(release: AppRelease ) -> AgentConfig: config_dict = release.config - + agent_config = AgentConfig( app_id=release.app_id, system_prompt=config_dict.get("system_prompt"), @@ -45,10 +98,10 @@ def agent_config_4_app_release(release: AppRelease ) -> AgentConfig: def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig: config_dict = release.config - + agent_config = MultiAgentConfig( - app_id=release.app_id, + app_id=release.app_id, default_model_config_id=release.default_model_config_id, model_parameters=config_dict.get("model_parameters"), master_agent_id=config_dict.get("master_agent_id"), @@ -58,11 +111,29 @@ def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig: routing_rules=config_dict.get("routing_rules"), execution_config=config_dict.get("execution_config", {}), aggregation_strategy=config_dict.get("aggregation_strategy", "merge"), - + ) return agent_config +def workflow_config_4_app_release(release: AppRelease ) -> WorkflowConfig: + + config_dict = release.config + + + config = WorkflowConfig( + id=release.id, + app_id=release.app_id, + nodes=config_dict.get("nodes", []), + edges=config_dict.get("edges", []), + variables=config_dict.get("variables", []), + execution_config=config_dict.get("execution_config", {}), + triggers=config_dict.get("triggers", []) + + ) + + return config + def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None): """Convert dict to MultiAgentConfig model object diff --git a/api/migrations/versions/52ebaf4ad3fb_202601081915.py b/api/migrations/versions/52ebaf4ad3fb_202601081915.py new 
file mode 100644 index 00000000..747a4c9a --- /dev/null +++ b/api/migrations/versions/52ebaf4ad3fb_202601081915.py @@ -0,0 +1,42 @@ +"""202601081915 + +Revision ID: 52ebaf4ad3fb +Revises: a959b201c507 +Create Date: 2026-01-08 19:15:43.830726 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '52ebaf4ad3fb' +down_revision: Union[str, None] = 'a959b201c507' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('conversation_details', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('conversation_id', sa.UUID(), nullable=True), + sa.Column('theme', sa.String(), nullable=True, comment='会话主题'), + sa.Column('theme_intro', sa.String(), nullable=True, comment='主题介绍'), + sa.Column('summary', sa.String(), nullable=True, comment='会话摘要'), + sa.Column('takeaways', sa.JSON(), nullable=True, comment='会话要点'), + sa.Column('info_score', sa.Integer(), nullable=True, comment='会话信息量评分'), + sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_conversation_details_id'), 'conversation_details', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f('ix_conversation_details_id'), table_name='conversation_details') + op.drop_table('conversation_details') + # ### end Alembic commands ### diff --git a/api/migrations/versions/793c31683aa5_202601121530.py b/api/migrations/versions/793c31683aa5_202601121530.py new file mode 100644 index 00000000..2cbd95d0 --- /dev/null +++ b/api/migrations/versions/793c31683aa5_202601121530.py @@ -0,0 +1,32 @@ +"""202601121530 + +Revision ID: 793c31683aa5 +Revises: 52ebaf4ad3fb +Create Date: 2026-01-12 15:29:03.135322 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '793c31683aa5' +down_revision: Union[str, None] = '52ebaf4ad3fb' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('conversation_details', sa.Column('question', sa.JSON(), nullable=True, comment='用户问题')) + op.drop_column('conversation_details', 'theme_intro') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('conversation_details', sa.Column('theme_intro', sa.VARCHAR(), autoincrement=False, nullable=True, comment='主题介绍')) + op.drop_column('conversation_details', 'question') + # ### end Alembic commands ### diff --git a/api/migrations/versions/8372101eda28_202601071400.py b/api/migrations/versions/8372101eda28_202601071400.py new file mode 100644 index 00000000..43ca9056 --- /dev/null +++ b/api/migrations/versions/8372101eda28_202601071400.py @@ -0,0 +1,30 @@ +"""202601071400 + +Revision ID: 8372101eda28 +Revises: 6064f41faac6 +Create Date: 2026-01-07 14:00:14.700994 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '8372101eda28' +down_revision: Union[str, None] = '6064f41faac6' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('data_config', 'llm') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('data_config', sa.Column('llm', sa.VARCHAR(), autoincrement=False, nullable=True, comment='LLM模型配置ID')) + # ### end Alembic commands ### diff --git a/api/migrations/versions/a959b201c507_202601081520.py b/api/migrations/versions/a959b201c507_202601081520.py new file mode 100644 index 00000000..02331592 --- /dev/null +++ b/api/migrations/versions/a959b201c507_202601081520.py @@ -0,0 +1,48 @@ +"""202601081520 + +Revision ID: a959b201c507 +Revises: c6d4afa27bf0 +Create Date: 2026-01-08 15:20:29.742666 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a959b201c507' +down_revision: Union[str, None] = 'c6d4afa27bf0' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('forgetting_cycle_history', + sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'), + sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'), + sa.Column('execution_time', sa.DateTime(), nullable=False, comment='执行时间'), + sa.Column('merged_count', sa.Integer(), nullable=True, comment='本次成功融合的节点对数'), + sa.Column('failed_count', sa.Integer(), nullable=True, comment='本次融合失败的节点对数'), + sa.Column('average_activation_value', sa.Float(), nullable=True, comment='平均激活值'), + sa.Column('total_nodes', sa.Integer(), nullable=True, comment='总节点数'), + sa.Column('low_activation_nodes', sa.Integer(), nullable=True, comment='低于遗忘阈值的节点总数(包含已融合、失败和待处理的)'), + sa.Column('duration_seconds', sa.Float(), nullable=True, comment='执行耗时(秒)'), + sa.Column('trigger_type', sa.String(length=50), nullable=True, comment='触发类型: manual/scheduled'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('idx_end_user_time', 'forgetting_cycle_history', ['end_user_id', 'execution_time'], unique=False) + op.create_index('idx_execution_time', 'forgetting_cycle_history', ['execution_time'], unique=False) + op.create_index(op.f('ix_forgetting_cycle_history_id'), 'forgetting_cycle_history', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f('ix_forgetting_cycle_history_id'), table_name='forgetting_cycle_history') + op.drop_index('idx_execution_time', table_name='forgetting_cycle_history') + op.drop_index('idx_end_user_time', table_name='forgetting_cycle_history') + op.drop_table('forgetting_cycle_history') + # ### end Alembic commands ### diff --git a/api/migrations/versions/c6d4afa27bf0_202601071800.py b/api/migrations/versions/c6d4afa27bf0_202601071800.py new file mode 100644 index 00000000..7d048f43 --- /dev/null +++ b/api/migrations/versions/c6d4afa27bf0_202601071800.py @@ -0,0 +1,88 @@ +"""202601071800 + +Revision ID: c6d4afa27bf0 +Revises: 8372101eda28 +Create Date: 2026-01-07 17:59:23.032323 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'c6d4afa27bf0' +down_revision: Union[str, None] = '8372101eda28' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('memory_long_term', + sa.Column('id', sa.UUID(), nullable=False, comment='记忆ID'), + sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'), + sa.Column('retrieved_content', sa.JSON(), nullable=True, comment='检索到的相关内容,格式为[{}, {}]'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_memory_long_term_created_at'), 'memory_long_term', ['created_at'], unique=False) + op.create_index(op.f('ix_memory_long_term_end_user_id'), 'memory_long_term', ['end_user_id'], unique=False) + op.create_index(op.f('ix_memory_long_term_id'), 'memory_long_term', ['id'], unique=False) + op.create_table('memory_short_term', + sa.Column('id', sa.UUID(), nullable=False, comment='记忆ID'), + sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'), + sa.Column('messages', sa.Text(), nullable=False, comment='用户消息内容'), + sa.Column('aimessages', sa.Text(), nullable=True, comment='AI回复消息内容'), + sa.Column('search_switch', sa.String(length=50), nullable=True, comment='搜索开关状态'), + sa.Column('retrieved_content', sa.JSON(), nullable=True, comment='检索到的相关内容,格式为[{}, {}]'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_memory_short_term_created_at'), 'memory_short_term', ['created_at'], unique=False) + op.create_index(op.f('ix_memory_short_term_end_user_id'), 'memory_short_term', ['end_user_id'], unique=False) + op.create_index(op.f('ix_memory_short_term_id'), 'memory_short_term', ['id'], unique=False) + op.create_table('memory_perceptual', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('end_user_id', sa.UUID(), nullable=True), + sa.Column('perceptual_type', sa.Integer(), nullable=False, comment='感知类型'), + sa.Column('storage_service', sa.Integer(), nullable=True, comment='存储服务类型'), + sa.Column('file_path', sa.String(), nullable=False, comment='文件路径'), 
def downgrade() -> None:
    """Undo the schema changes made by this revision.

    First restores the previous comment on
    ``multi_agent_configs.orchestration_mode``; then, for each memory table,
    drops its indexes followed by the table itself. Tables are removed in the
    reverse of the order in which the upgrade creates them
    (perceptual, short-term, long-term).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column(
        'multi_agent_configs',
        'orchestration_mode',
        existing_type=sa.VARCHAR(length=20),
        comment='协作模式: sequential|parallel|conditional|loop',
        existing_comment='协作模式: collaboration（协作）| supervisor（监督）',
        existing_nullable=False,
    )

    # (table, indexes) pairs; indexes are listed in the exact order they are
    # dropped, and every index goes away before its owning table does.
    teardown_plan = (
        (
            'memory_perceptual',
            (
                'ix_memory_perceptual_perceptual_type',
                'ix_memory_perceptual_end_user_id',
            ),
        ),
        (
            'memory_short_term',
            (
                'ix_memory_short_term_id',
                'ix_memory_short_term_end_user_id',
                'ix_memory_short_term_created_at',
            ),
        ),
        (
            'memory_long_term',
            (
                'ix_memory_long_term_id',
                'ix_memory_long_term_end_user_id',
                'ix_memory_long_term_created_at',
            ),
        ),
    )
    for table, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(op.f(index_name), table_name=table)
        op.drop_table(table)
    # ### end Alembic commands ###
/** Implicit memory – core traits: GET the portrait of the given end user. */
export const getImplicitPortrait = (end_user_id: string) =>
  request.get(`/memory/implicit-memory/portrait/${end_user_id}`)

/** Implicit memory – interest area distribution for the given end user. */
export const getImplicitInterestAreas = (end_user_id: string) =>
  request.get(`/memory/implicit-memory/interest-areas/${end_user_id}`)

/** Implicit memory – user habit analysis for the given end user. */
export const getImplicitHabits = (end_user_id: string) =>
  request.get(`/memory/implicit-memory/habits/${end_user_id}`)

/** Short-term memory: GET with `end_user_id` passed as the request data. */
export const getShortTerm = (end_user_id: string) =>
  request.get(`/memory/short/short_term`, { end_user_id })

/** Perceptual memory – visual memory (`last_visual`) of the given end user. */
export const getPerceptualLastVisual = (end_user: string) =>
  request.get(`/memory/perceptual/${end_user}/last_visual`)

/** Perceptual memory – audio memory (`last_listen`) of the given end user. */
export const getPerceptualLastListen = (end_user: string) =>
  request.get(`/memory/perceptual/${end_user}/last_listen`)

/** Perceptual memory – text memory (`last_text`) of the given end user. */
export const getPerceptualLastText = (end_user: string) =>
  request.get(`/memory/perceptual/${end_user}/last_text`)

/** Perceptual memory – perceptual timeline of the given end user. */
export const getPerceptualTimeline = (end_user: string) =>
  request.get(`/memory/perceptual/${end_user}/timeline`)

/** Episodic memory – overview: POSTs the filter object as the request body. */
export const getEpisodicOverview = (data: { end_user_id: string; time_range: string; episodic_type: string; } ) =>
  request.post(`/memory-storage/classifications/episodic-memory`, data)

/** Episodic memory – details: POSTs the end user id and summary id. */
export const getEpisodicDetail = (data: { end_user_id: string; summary_id: string; } ) =>
  request.post(`/memory-storage/classifications/episodic-memory-details`, data)

/** Relationship evolution: GET with `id` and `label` passed as the request data. */
export const getRelationshipEvolution = (data: { id: string; label: string; } ) =>
  request.get(`/memory-storage/memory_space/relationship_evolution`, data)

/** Shared-memory timeline: GET with `id` and `label` passed as the request data. */
export const getTimelineMemories = (data: { id: string; label: string; }) =>
  request.get(`/memory-storage/memory_space/timeline_memories`, data)

/** Explicit memory: POSTs `{ end_user_id }` to the explicit-memory endpoint. */
export const getExplicitMemory = (end_user_id: string) =>
  request.post(`/memory-storage/classifications/explicit-memory`, { end_user_id })

/** Explicit memory – details: POSTs the end user id and memory id. */
export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) =>
  request.post(`/memory-storage/classifications/explicit-memory-details`, data)

/** Working memory: GET the conversations endpoint of the given end user. */
export const getConversations = (end_user: string) =>
  request.get(`/memory/work/${end_user}/conversations`)

/** Working memory: GET the messages endpoint, passing `conversation_id` as request data. */
export const getConversationMessages = (end_user: string, conversation_id: string) =>
  request.get(`/memory/work/${end_user}/messages`, { conversation_id })

/** Working memory: GET the detail endpoint, passing `conversation_id` as request data. */
export const getConversationDetail = (end_user: string, conversation_id: string) =>
  request.get(`/memory/work/${end_user}/detail`, { conversation_id })
files /dev/null and b/web/src/assets/images/empty/pageLoading.png differ diff --git a/web/src/assets/images/userMemory/arrow_right_hover.svg b/web/src/assets/images/userMemory/arrow_right_hover.svg new file mode 100644 index 00000000..0fed7c6b --- /dev/null +++ b/web/src/assets/images/userMemory/arrow_right_hover.svg @@ -0,0 +1,14 @@ + + + 编组 5 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/detail_empty.png b/web/src/assets/images/userMemory/detail_empty.png index 67666235..3cada856 100644 Binary files a/web/src/assets/images/userMemory/detail_empty.png and b/web/src/assets/images/userMemory/detail_empty.png differ diff --git a/web/src/assets/images/userMemory/shortTerm.png b/web/src/assets/images/userMemory/shortTerm.png new file mode 100644 index 00000000..37a880ec Binary files /dev/null and b/web/src/assets/images/userMemory/shortTerm.png differ diff --git a/web/src/assets/images/userMemory/up_border.svg b/web/src/assets/images/userMemory/up_border.svg new file mode 100644 index 00000000..a7fe9978 --- /dev/null +++ b/web/src/assets/images/userMemory/up_border.svg @@ -0,0 +1,14 @@ + + + 下拉备份 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/view.svg b/web/src/assets/images/userMemory/view.svg new file mode 100644 index 00000000..642841ae --- /dev/null +++ b/web/src/assets/images/userMemory/view.svg @@ -0,0 +1,19 @@ + + + 查看 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/userMemory/view_hover.svg b/web/src/assets/images/userMemory/view_hover.svg new file mode 100644 index 00000000..642841ae --- /dev/null +++ b/web/src/assets/images/userMemory/view_hover.svg @@ -0,0 +1,19 @@ + + + 查看 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/workflow/memory-read.png b/web/src/assets/images/workflow/memory-read.png new file mode 100644 index 00000000..4b0cdc1d Binary files /dev/null and 
b/web/src/assets/images/workflow/memory-read.png differ diff --git a/web/src/assets/images/workflow/memory-write.png b/web/src/assets/images/workflow/memory-write.png new file mode 100644 index 00000000..83a50fd4 Binary files /dev/null and b/web/src/assets/images/workflow/memory-write.png differ diff --git a/web/src/components/Empty/PageLoading.tsx b/web/src/components/Empty/PageLoading.tsx new file mode 100644 index 00000000..df5041d9 --- /dev/null +++ b/web/src/components/Empty/PageLoading.tsx @@ -0,0 +1,16 @@ +import { useTranslation } from 'react-i18next' +import LoadingIcon from '@/assets/images/empty/pageLoading.png' +import Empty from './index' +const PageLoading = ({ size = [240, 210] }: { size?: number | number[] }) => { + const { t } = useTranslation() + return ( + + ) +} +export default PageLoading; \ No newline at end of file diff --git a/web/src/components/StatusTag/index.tsx b/web/src/components/StatusTag/index.tsx index 70b720b5..273fff1b 100644 --- a/web/src/components/StatusTag/index.tsx +++ b/web/src/components/StatusTag/index.tsx @@ -3,23 +3,27 @@ import { Tag } from 'antd'; import clsx from 'clsx'; interface StatusTagProps { - status: 'success' | 'error' | 'warning', + status: 'success' | 'error' | 'warning' | 'default' | 'lightBlue' | 'purple', text: string; } const Colors = { success: 'rb:bg-[#369F21]', error: 'rb:bg-[#FF5D34]', warning: 'rb:bg-[#FF8A4C]', + default: 'rb:bg-[#155EEF]', + lightBlue: 'rb:bg-[#4DA8FF]', + purple: 'rb:bg-[#9C6FFF]' } const StatusTag: FC = ({ status, text }) => { + console.log('status', status) return ( - - + + { text } diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index b9c69e25..408c9919 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1147,10 +1147,10 @@ export const en = { promptEmpty: 'Describe your use case on the left, and the orchestration preview will be displayed here.', master: 'Supervisor Mode', - master_agent: 'Supervisor Mode', - master_agentDesc: 'Unified scheduling and management 
by the main Agent, with sub-Agents executing tasks assigned by the supervisor, suitable for scenarios requiring centralized control.', - handoffs: 'Collaboration Mode', - handoffsDesc: 'Multiple Agents collaborate equally, autonomously coordinating according to task requirements, suitable for complex scenarios requiring flexible interaction.', + supervisor: 'Supervisor Mode', + supervisorDesc: 'Unified scheduling and management by the main Agent, with sub-Agents executing tasks assigned by the supervisor, suitable for scenarios requiring centralized control.', + collaboration: 'Collaboration Mode', + collaborationDesc: 'Multiple Agents collaborate equally, autonomously coordinating according to task requirements, suitable for complex scenarios requiring flexible interaction.', masterConfig: 'Supervisor Configuration', orchestrationMode: 'Task Assignment Strategy', conditional: 'Intelligent Assignment', @@ -1160,6 +1160,8 @@ export const en = { merge: 'Complete Aggregation', vote: 'Key Information Extraction', priority: 'Structured Integration', + addTool: 'Add Tool', + tool: 'Tool', }, userMemory: { userMemory: 'User Memory', @@ -1204,10 +1206,6 @@ export const en = { nodeStatistics: 'Memory Classification', total: 'Total', - Chunk: 'Long-term Memory', - MemorySummary: 'Episodic Memory', - Statement: 'Emotional Memory', - ExtractedEntity: 'Short-term Memory', PERCEPTUAL_MEMORY: 'Perceptual Memory', WORKING_MEMORY: 'Working Memory', @@ -1217,6 +1215,7 @@ export const en = { IMPLICIT_MEMORY: 'Implicit Memory', EMOTIONAL_MEMORY: 'Emotional Memory', EPISODIC_MEMORY: 'Episodic Memory', + FORGETTING_MANAGEMENT: 'Forgetting Management', endUserProfile: 'Core Profile', editEndUserProfile: 'Edit', @@ -1234,6 +1233,33 @@ export const en = { key_findings: 'Key Findings', behavior_pattern: 'Behavior Pattern', growth_trajectory: 'Growth Trajectory', + personality: 'Personality Traits', + core_values: 'Core Values', + + Statement_emotion_keywords: 'Emotion Keywords', + 
Statement_emotion_type: 'Emotion Type', + Statement_emotion_subject: 'Emotion Subject', + Statement_importance_score: 'Importance Score', + + ExtractedEntity_description: 'Description', + ExtractedEntity_name: 'Content', + ExtractedEntity_entity_type: 'Type', + ExtractedEntity_created_at: 'Created At', + ExtractedEntity_aliases: 'Aliases', + ExtractedEntity_connect_strngth: 'Connection Strength', + ExtractedEntity_importance_score: 'Importance Score', + + associative_memory: 'Associative Memory', + unix: 'items', + completeMemory: 'Complete Memory', + relationshipEvolution: 'Relationship Evolution', + timelineMemories: 'Shared Memory Timeline', + emotionLine: 'Emotion Changes Over Time', + interaction: 'Interaction Frequency & Relationship Stages', + timelines_memory: 'All', + MemorySummary: 'Long-term Accumulation', + Statement: 'Emotional Memory', + ExtractedEntity: 'Episodic Memory', }, space: { createSpace: 'Create Space', @@ -1809,12 +1835,20 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re "not_contains": 'Does Not Contain', "startwith": 'Starts With', "endwith": 'Ends With', - "eq": '==', - "ne": '!=', - "lt": '<', - "le": '<=', - "gt": '>', - "ge": '>=', + "eq": 'Equals', + "ne": 'Not Equals', + num: { + "eq": '=', + "ne": '≠', + "lt": '<', + "le": '≤', + "gt": '>', + "ge": '≥', + }, + boolean: { + "eq": 'Is', + "ne": 'Is Not', + }, else_desc: 'Used to define the logic that should be executed when the if condition is not met.' 
}, 'http-request': { @@ -1839,6 +1873,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re status_code: 'Status Code', max_attempts: 'Max Retry Attempts', retry_interval: 'Retry Interval', + errorBranch: 'Error Branch', }, 'jinja-render': { template: 'Code', @@ -1855,12 +1890,17 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re loop: { cycle_vars: 'Loop Variables', condition: 'Loop Termination Condition', + max_loop: 'Maximum Loop Count', }, assigner: { assignments: 'Variables', - cover: 'Overwrite', + cover: 'Override', assign: 'Set', - clear: 'Clear' + clear: 'Clear', + add: '+=', + subtract: '-=', + multiply: '*=', + divide: '/=', }, iteration: { input: 'Input Variable', @@ -1963,6 +2003,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re }, statementDetail: { wordCloud: 'Emotion Distribution Analysis', + totalCount: 'Sample Count', pieces: 'items', emotionTags: 'High-Frequency Emotion Keywords', joy: 'Joy', @@ -2147,5 +2188,124 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re orderPayInfo: 'Payment Information', create_time: 'Creation Time', }, + forgetDetail: { + title: 'The forgetting management system helps AI intelligently manage memory lifecycle by automatically identifying low-value memories, setting forgetting strategies, and executing regular cleanup to optimize memory storage space and improve retrieval efficiency.', + overviewTitle: 'Core Metrics Overview', + totalMemory: 'Total Memory', + MemoryHealth: 'Memory Health', + riskOfForgetting: 'Forgetting Risk', + statement_count: 'Statements', + entity_count: 'Entities', + summary_count: 'Summaries', + chunk_count: 'Chunks', + healthStatus: 'Health Status', + average: 'Average Activation Value', + threshold: 'Threshold Reference:', + unhealthy: 'Unhealthy', + healthy: 'Healthy', + low_nodes: 'Low Activation Nodes', + memoryHealthVisualization: 'Memory Health 
Visualization', + activationValueDistribution: 'Activation Value Distribution', + forgettingTrend: 'Forgetting Trend (Last 7 Days)', + + nodes_without_activation: 'Observation Zone', + low_activation_nodes: 'Forgetting Zone', + health_nodes: 'Healthy Zone', + average_activation: 'Average Activation Value', + merged_count: 'Daily Merged Node Count', + + pending_nodes: 'Risk Node Forgetting Pool', + content_summary: 'Content Summary', + node_type: 'Node Type', + last_access_time: 'Last Activation Time', + activation_value: 'Current Activation Value', + }, + episodicDetail: { + title: 'Record every important scene you have truly experienced', + total_all: 'Total Episodic Memories', + all: "All", + today: 'Today', + this_week: 'This Week', + this_month: 'This Month', + conversation: "Conversation", + project_work: "Project/Work", + learning: "Learning", + decision: "Decision", + important_event: "Important Event", + titleKeywordPlaceholder: 'Search episode title or content', + curResult: 'Current Filter Results', + unix: 'items', + created: 'Occurrence Time', + episodic_type: 'Episode Type', + involved_objects: 'Involved Objects', + content_records: 'Episode Content Records', + emotion: 'Emotion and State Records', + }, + implicitDetail: { + title: 'The invisible forces that shaped me', + preferences: 'My Subconscious Preferences', + preferencesDetail: 'Association Network', + portraitTitle: 'My Subconscious Portrait', + portraitSubTitle: 'Personalized insights generated by AI based on your preference tags', + portrait: 'Core Traits', + aesthetic: 'Aesthetic Driven', + creativity: 'Creative Thinking', + literature: 'Cultural Sensitivity', + technology: 'Technology Affinity', + interestAreas: 'Interest Area Distribution', + art: 'Art & Design', + music: 'Music & Culture', + tech: 'Technology & Future', + lifestyle: 'Lifestyle', + habits: 'User Habit Analysis', + habitsSubTitle: 'Habit characteristics identified based on your behavior patterns', + context_details: 
'Preference Details', + supporting_evidence: 'Preference Source', + specific_examples: 'Source', + }, + shortTermDetail: { + title: 'Short-term memory is the "workbench" of the AI system, connecting instant conversations with long-term knowledge bases. Through real-time capture, deep retrieval, intelligent extraction and filtering transformation, temporary unstructured information is converted into valuable long-term knowledge.', + retrieval_number: 'Retrieval Count', + entity: 'Extracted Entities', + long_term_number: 'Long-term Candidates', + shortTermTitle: 'Deep Retrieval & Extended Answer Area', + shortTermSubTitle: 'Stores deep information retrieval performed to answer questions and the extended answers generated from it, including original questions, retrieved information, and generated answers.', + longTermTitle: 'Long-term Memory Candidate Pool', + longTermSubTitle: 'Aggregates short-term memory, filters and prepares content for storage in long-term memory. This is the "transfer station" and "filter" from short-term to long-term memory.', + answer: 'Answer', + query: 'Question', + noAnswer: 'No reply yet', + }, + perceptualDetail: { + last_visual: 'Visual Perception Stream', + last_listen: 'Auditory Perception Stream', + last_text: 'Text Perception', + summary: 'Summary', + keywords: 'Keywords', + topic: 'Topic', + domain: 'Domain', + scene: 'Scene', + speaker_count: 'Number of Speakers', + section_count: 'Number of Sections', + timeLine: 'Perception Timeline', + lastInfo: 'Real-time Perception Dashboard', + }, + explicitDetail: { + episodic_memories: 'Episodic Memories', + semantic_memories: 'Semantic Memories', + content: 'Core Description', + created_at: 'Created At', + emotion: 'Emotion', + core_definition: 'Core Definition', + detailed_notes: 'Detailed Notes', + }, + workingDetail: { + conversationStream: 'Real-time Conversation Stream', + refresh: 'Refresh', + successfulTitle: 'Successful Experience', + question: 'Lessons Learned', + summary: 'Core 
Insights', + none: 'None' + } }, }; diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 83b78c76..9abc57cf 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -636,10 +636,10 @@ export const zh = { promptEmpty: '在左侧描述您的用例,编排预览将在此处显示。', master: '主管模式', - master_agent: '主管模式', - master_agentDesc: '由主 Agent 统一调度和管理,子 Agent 按照主管分配的任务执行,适合需要集中控制的场景。', - handoffs: '协作模式', - handoffsDesc: '多个 Agent 平等协作,根据任务需求自主协调配合,适合需要灵活互动的复杂场景。', + supervisor: '主管模式', + supervisorDesc: '由主 Agent 统一调度和管理,子 Agent 按照主管分配的任务执行,适合需要集中控制的场景。', + collaboration: '协作模式', + collaborationDesc: '多个 Agent 平等协作,根据任务需求自主协调配合,适合需要灵活互动的复杂场景。', masterConfig: '主管配置', orchestrationMode: '任务分配策略', conditional: '智能分配', @@ -649,6 +649,8 @@ export const zh = { merge: '完整汇总', vote: '关键信息提取', priority: '结构化整合', + addTool: '添加工具', + tool: '工具', }, // 角色管理相关翻译 role: { @@ -1283,11 +1285,6 @@ export const zh = { nodeStatistics: '记忆分类', total: '总计', - Chunk: '长期记忆', - MemorySummary: '情景记忆', - Statement: '情绪记忆', - ExtractedEntity: '短期记忆', - PERCEPTUAL_MEMORY: '感知记忆', WORKING_MEMORY: '工作记忆', SHORT_TERM_MEMORY: '短期记忆', @@ -1296,6 +1293,7 @@ export const zh = { IMPLICIT_MEMORY: '隐性记忆', EMOTIONAL_MEMORY: '情绪记忆', EPISODIC_MEMORY: '情景记忆', + FORGETTING_MANAGEMENT: '遗忘', endUserProfile: '核心档案', editEndUserProfile: '编辑', @@ -1315,6 +1313,33 @@ export const zh = { key_findings: '关键发现', behavior_pattern: '行为模式', growth_trajectory: '成长轨迹', + personality: '性格特点', + core_values: '核心价值观', + + Statement_emotion_keywords: '情感关键词', + Statement_emotion_type: '情感类型', + Statement_emotion_subject: '情感主体', + Statement_importance_score: '重要性评分', + + ExtractedEntity_description: '描述', + ExtractedEntity_name: '内容', + ExtractedEntity_entity_type: '类型', + ExtractedEntity_created_at: '创建时间', + ExtractedEntity_aliases: '别名', + ExtractedEntity_connect_strngth: '连接强度', + ExtractedEntity_importance_score: '重要性评分', + + associative_memory: '关联记忆', + unix: '个', + completeMemory: '完整记忆', + relationshipEvolution: '关系演化', + 
timelineMemories: '共同记忆时间线', + emotionLine: '情绪随时间变化', + interaction: '互动频率 & 关系阶段', + timelines_memory: '全部', + MemorySummary: '长期沉淀', + Statement: '情绪记忆', + ExtractedEntity: '情景记忆', }, space: { createSpace: '创建空间', @@ -1909,12 +1934,20 @@ export const zh = { "not_contains": '不包含', "startwith": '开始是', "endwith": '结束是', - "eq": '==', - "ne": '!=', - "lt": '<', - "le": '<=', - "gt": '>', - "ge": '>=', + "eq": '是', + "ne": '不是', + num: { + "eq": '=', + "ne": '≠', + "lt": '<', + "le": '≤', + "gt": '>', + "ge": '≥', + }, + boolean: { + "eq": '是', + "ne": '不是', + }, else_desc: '用于定义当 if 条件不满足时应执行的逻辑。' }, 'http-request': { @@ -1939,6 +1972,7 @@ export const zh = { status_code: '状态码', max_attempts: '最大重试次数', retry_interval: '重试间隔', + errorBranch: '异常分支', }, 'jinja-render': { template: '代码', @@ -1955,12 +1989,17 @@ export const zh = { loop: { cycle_vars: '循环变量', condition: '循环终止条件', + max_loop: '最大循环次数', }, assigner: { assignments: '变量', cover: '覆盖', assign: '设置', - clear: '清空' + clear: '清空', + add: '+=', + subtract: '-=', + multiply: '*=', + divide: '/=', }, iteration: { input: '输入变量', @@ -2063,6 +2102,7 @@ export const zh = { }, statementDetail: { wordCloud: '情感分布分析', + totalCount: '样本数', pieces: '条', emotionTags: '高频情绪关键词', joy: '喜悦', @@ -2247,5 +2287,124 @@ export const zh = { orderPayInfo: '支付信息', create_time: '创建时间', }, + forgetDetail: { + title: '遗忘管理系统帮助AI智能管理记忆生命周期,通过自动识别低价值记忆、设置遗忘策略和执行定期清理,优化记忆库存储空间,提升检索效率。', + overviewTitle: '核心指标概览', + totalMemory: '记忆总量', + MemoryHealth: '记忆健康度', + riskOfForgetting: '遗忘风险', + statement_count: '陈述', + entity_count: '实体', + summary_count: '摘要', + chunk_count: '片段', + healthStatus: '健康状态', + average: '平均激活值', + threshold: '阈值参考:', + unhealthy: '不健康', + healthy: '健康', + low_nodes: '低激活节点', + memoryHealthVisualization: '记忆健康可视化', + activationValueDistribution: '激活值分布', + forgettingTrend: '遗忘趋势(近7天)', + + nodes_without_activation: '观察区', + low_activation_nodes: '遗忘区', + health_nodes: '健康区', + average_activation: '平均激活值', + 
merged_count: '每日融合节点数', + + pending_nodes: '风险节点遗忘池', + content_summary: '内容摘要', + node_type: '节点类型', + last_access_time: '最后激活时间', + activation_value: '当前激活值', + }, + episodicDetail: { + title: '记录你真实经历过的每一个重要场景', + total_all: '情景记忆总数', + all: "全部", + today: '今天', + this_week: '本周', + this_month: '本月', + conversation: "对话", + project_work: "项目/工作", + learning: "学习", + decision: "决策", + important_event: "重要事件", + titleKeywordPlaceholder: '搜索情景标题或内容', + curResult: '当前筛选结果', + unix: '条', + created: '发生时间', + episodic_type: '情景类型', + involved_objects: '涉及对象', + content_records: '情景内容记录', + emotion: '情绪与状态记录', + }, + implicitDetail: { + title: '那些塑造了我的无形力量', + preferences: '我的潜意识偏好', + preferencesDetail: '的联想网络', + portraitTitle: '我的潜意识画像', + portraitSubTitle: '基于您的偏好标签,AI为您生成的个性化洞察', + portrait: '核心特质', + aesthetic: '审美驱动', + creativity: '创造性思维', + literature: '文化敏感度', + technology: '技术亲和力', + interestAreas: '兴趣领域分布', + art: '艺术与设计', + music: '音乐与文化', + tech: '科技与未来', + lifestyle: '生活方式', + habits: '用户习惯分析', + habitsSubTitle: '基于您的行为模式识别的习惯特征', + context_details: '偏好详情', + supporting_evidence: '偏好来源', + specific_examples: '来源', + }, + shortTermDetail: { + title: '短期记忆是AI系统的"工作台",连接即时对话与长期知识库。通过实时捕获、深度检索、智能提取和筛选转化,将临时的非结构化信息转化为有价值的长期知识。', + retrieval_number: '检索次数', + entity: '提取实体', + long_term_number: '长期候选', + shortTermTitle: '深度检索与扩展答案区', + shortTermSubTitle: '存放为回答问题而进行的深度信息检索和由此生成的扩展答案,包含原始问题、检索信息和生成答案。', + longTermTitle: '长期记忆候选池', + longTermSubTitle: '聚合短期记忆,筛选并准备存入长期记忆的内容。这是从短时记忆到长时记忆的"中转站"和"过滤器"。', + answer: '回答', + query: '问题', + noAnswer: '暂无回复', + }, + perceptualDetail: { + last_visual: '视觉感知流', + last_listen: '听觉感知流', + last_text: '文本感知', + summary: '摘要', + keywords: '关键词', + topic: '主题', + domain: '领域', + scene: '场景', + speaker_count: '对话人数', + section_count: '段落数', + timeLine: '感知时间线', + lastInfo: '实时感知仪表盘', + }, + explicitDetail: { + episodic_memories: '情景记忆', + semantic_memories: '语义记忆', + content: '核心描述', + created_at: '创建时间', + emotion: '情绪', + 
core_definition: '核心定义', + detailed_notes: '详细笔记', + }, + workingDetail: { + conversationStream: '实时对话流', + refresh: '刷新', + successfulTitle: '成功经验', + question: '踩过的坑', + summary: '核心洞察', + none: '无' + } }, } \ No newline at end of file diff --git a/web/src/routes/index.tsx b/web/src/routes/index.tsx index bac985bf..f3bd9c2d 100644 --- a/web/src/routes/index.tsx +++ b/web/src/routes/index.tsx @@ -59,6 +59,8 @@ const componentMap: Record>> = ApiKeyManagement: lazy(() => import('@/views/ApiKeyManagement')), EmotionEngine: lazy(() => import('@/views/EmotionEngine')), StatementDetail: lazy(() => import('@/views/UserMemoryDetail/pages/StatementDetail')), + ForgetDetail: lazy(() => import('@/views/UserMemoryDetail/pages/ForgetDetail')), + MemoryNodeDetail: lazy(() => import('@/views/UserMemoryDetail/pages/index')), SelfReflectionEngine: lazy(() => import('@/views/SelfReflectionEngine')), OrderPayment: lazy(() => import('@/views/OrderPayment')), OrderHistory: lazy(() => import('@/views/OrderHistory')), diff --git a/web/src/routes/routes.json b/web/src/routes/routes.json index a29d1c63..ca6a3271 100644 --- a/web/src/routes/routes.json +++ b/web/src/routes/routes.json @@ -43,7 +43,8 @@ { "path": "/application/config/:id", "element": "ApplicationConfig" }, { "path": "/conversation/:token", "element": "Conversation" }, { "path": "/user-memory/neo4j/:id", "element": "Neo4jUserMemoryDetail" }, - { "path": "/statement/:id", "element": "StatementDetail" } + { "path": "/statement/:id", "element": "StatementDetail" }, + { "path": "/user-memory/detail/:id/:type", "element": "MemoryNodeDetail" } ] }, { diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index c6aa63e8..ce51c622 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -18,7 +18,8 @@ import type { Variable, MemoryConfig, AiPromptModalRef, - Source + Source, + ToolOption } from './types' import type { Model } from 
'@/views/ModelManagement/types' import { getModelList } from '@/api/models'; @@ -31,6 +32,7 @@ import { memoryConfigListUrl } from '@/api/memory' import CustomSelect from '@/components/CustomSelect' import aiPrompt from '@/assets/images/application/aiPrompt.png' import AiPromptModal from './components/AiPromptModal' +import ToolList from './components/ToolList' const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className}) => { return ( @@ -47,12 +49,12 @@ const LabelWrapper: FC<{title: string, className?: string; children?: ReactNode} ) } -const SwitchWrapper: FC<{ title: string, desc: string, name: string }> = ({ title, desc, name }) => { +const SwitchWrapper: FC<{ title: string, desc?: string, name: string | string[]; needTransition?: boolean; }> = ({ title, desc, name, needTransition = true }) => { const { t } = useTranslation(); return (
- - + + {desc && } ((_props, ref) => { const [formData, setFormData] = useState<{ default_model_config_id?: string, model_parameters?: Config['model_parameters'], + tools: ToolOption[], } | null>(null) const values = Form.useWatch<{ memoryEnabled: boolean; memory_content?: string | number; - webSearch: boolean; } & Config>([], form) const [knowledgeConfig, setKnowledgeConfig] = useState({ knowledge_bases: [] }) const [variableList, setVariableList] = useState([]) const [isSave, setIsSave] = useState(false) const initialized = useRef(false) + const [toolList, setToolList] = useState([]) // 初始化完成标记 useEffect(() => { @@ -139,6 +142,11 @@ const Agent = forwardRef((_props, ref) => { if (isSave) return setIsSave(true) }, [values]) + useEffect(() => { + if (!initialized.current) return + if (isSave) return + setIsSave(true) + }, [toolList]) useEffect(() => { getModels() @@ -149,17 +157,21 @@ const Agent = forwardRef((_props, ref) => { setLoading(true) getApplicationConfig(id as string).then(res => { const response = res as Config - setData(response) + setData({ + ...response, + tools: Array.isArray(response.tools) ? response.tools : [] + }) const { memory, tools } = response form.setFieldsValue({ ...response, memoryEnabled: memory?.enabled || false, memory_content: memory?.memory_content ? Number(memory?.memory_content) : undefined, - webSearch: tools?.web_search?.enabled || false, + tools: Array.isArray(tools) ? tools : [] }) setFormData({ default_model_config_id: response.default_model_config_id, model_parameters: response.model_parameters || {}, + tools: Array.isArray(tools) ? 
tools : [] }) if (response?.knowledge_retrieval?.knowledge_bases?.length) { getDefaultKnowledgeList(response) @@ -260,8 +272,9 @@ const Agent = forwardRef((_props, ref) => { // 保存Agent配置 const handleSave = (flag = true) => { if (!isSave || !data) return Promise.resolve() - const { memoryEnabled, memory_content, webSearch, ...rest } = values + const { memoryEnabled, memory_content, ...rest } = values const { knowledge_bases = [], ...knowledgeRest } = knowledgeConfig || {} + // 从原数据中获取memory的其他必要属性 const originalMemory = data.memory || ({} as MemoryConfig) @@ -285,15 +298,14 @@ const Agent = forwardRef((_props, ref) => { ...(item.config || {}) })) } as KnowledgeConfig : null, - tools: { - web_search: { - enabled: webSearch, - config: { - web_search: webSearch - } - } - } + tools: toolList.map(vo => ({ + tool_id: vo.tool_id, + operation: vo.operation, + enabled: vo.enabled + })) } + + console.log('params', rest, params) return new Promise((resolve, reject) => { saveAgentConfig(data.app_id, params) @@ -342,6 +354,7 @@ const Agent = forwardRef((_props, ref) => { const updatePrompt = (value: string) => { form.setFieldValue('system_prompt', value) } + return ( <> {loading && } @@ -410,14 +423,12 @@ const Agent = forwardRef((_props, ref) => { data={data?.variables} onUpdate={setVariableList} /> + {/* 工具配置 */} - - - - {/* - */} - - + diff --git a/web/src/views/ApplicationConfig/Cluster.tsx b/web/src/views/ApplicationConfig/Cluster.tsx index ec38c96a..66245446 100644 --- a/web/src/views/ApplicationConfig/Cluster.tsx +++ b/web/src/views/ApplicationConfig/Cluster.tsx @@ -42,7 +42,7 @@ const Cluster = forwardRef((_props, ref) => { const handleSave = (flag = true) => { if (!data) return Promise.resolve() - if (!values.default_model_config_id) { + if (!values.default_model_config_id && values.orchestration_mode === 'supervisor') { message.warning(t('common.selectPlaceholder', { title: t('application.model') })) return Promise.resolve() } @@ -138,17 +138,16 @@ const Cluster = 
forwardRef((_props, ref) => {
- + ({ + options={['supervisor', 'collaboration'].map((type) => ({ value: type, label: t(`application.${type}`), labelDesc: t(`application.${type}Desc`), - disabled: type === 'handoffs' }))} allowClear={false} /> @@ -192,7 +191,7 @@ const Cluster = forwardRef((_props, ref) => { ))} - + {values?.orchestration_mode !== 'collaboration' && ((_props, ref) => { + + + + + + + + + + + + + {t('episodicDetail.curResult')} ({data.total || 0}{t('episodicDetail.unix')})} + headerType="borderless" + > + {loading + ? + : !data.episodic_memories || data.episodic_memories.length === 0 + ? + : ( + + {data.episodic_memories.map((vo, index) => ( +
setSelected(vo)} + > +
{index + 1}
+
+
{vo.title} {t(`episodicDetail.${getTypeKey(vo.type)}`)}
+
{formatDateTime(vo.created_at)}
+
+
+ ))} +
+ ) + } + +
+ + + + {detailLoading + ? + : !selected || !detail + ? + : ( + +
+ + +
{t('episodicDetail.created')}
{formatDateTime(detail.created_at)}
+ + +
{t('episodicDetail.episodic_type')}
{detail.episodic_type}
+ + {detail.involved_objects.length > 0 && +
{t('episodicDetail.involved_objects')}
+ {detail.involved_objects.map((vo, index) => {vo})} + } +
+
+
+
{t('episodicDetail.content_records')}
+ {detail.content_records.map((vo, index) =>
- {vo}
)} +
+ + {t('episodicDetail.emotion')}: {t(`statementDetail.${detail.emotion}`)} + +
+ ) + } +
+ +
+ + ) +} +export default EpisodicDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/ExplicitDetail.tsx b/web/src/views/UserMemoryDetail/pages/ExplicitDetail.tsx new file mode 100644 index 00000000..286e71be --- /dev/null +++ b/web/src/views/UserMemoryDetail/pages/ExplicitDetail.tsx @@ -0,0 +1,109 @@ +import { type FC, useEffect, useState, useRef } from 'react' +import { useTranslation } from 'react-i18next' +import { useParams } from 'react-router-dom' +import { List, Skeleton, Row, Col } from 'antd' +import RbCard from '@/components/RbCard/Card' +import { + getExplicitMemory, +} from '@/api/memory' +import { formatDateTime } from '@/utils/format' +import Empty from '@/components/Empty' +import ExplicitDetailModal from '../components/ExplicitDetailModal' + +export interface EpisodicMemory { + id: string; + title: string; + content: string; + created_at: number; +} +export interface SemanticMemory { + id: string; + name: string; + entity_type: string; + core_definition: string; + created_at: number; +} +interface Data { + episodic_memories: EpisodicMemory[]; + semantic_memories: SemanticMemory[] +} + +export interface ExplicitDetailModalRef { + handleOpen: (vo: EpisodicMemory | SemanticMemory) => void; +} + +const ExplicitDetail: FC = () => { + const { t } = useTranslation() + const { id } = useParams() + const explicitDetailModalRef = useRef(null) + const [loading, setLoading] = useState(false) + const [data, setData] = useState({ episodic_memories: [], semantic_memories: [] }) + + useEffect(() => { + if (!id) return + getData() + }, [id]) + + const getData = () => { + if (!id) return + setLoading(true) + getExplicitMemory(id).then((res) => { + const response = res as Data + setData(response) + setLoading(false) + }) + .finally(() => { + setLoading(false) + }) + } + const handleView = (item: EpisodicMemory | SemanticMemory) => { + explicitDetailModalRef.current?.handleOpen(item) + } + return ( +
+
{t('explicitDetail.episodic_memories')}
+ {loading ? + + : data.episodic_memories?.length > 0 ? ( + + {data.episodic_memories.map(item => ( + + handleView(item)} + > +
{formatDateTime(item.created_at)}
+
{item.content}
+
+ + ))} +
+ ) : } + +
{t('explicitDetail.semantic_memories')}
+ {loading ? + + : data.semantic_memories?.length > 0 ? ( + + {data.semantic_memories.map(item => ( + + handleView(item)} + > +
{item.core_definition}
+
+ + ))} +
+ ) : } + + +
+ ) +} +export default ExplicitDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx new file mode 100644 index 00000000..602dbf25 --- /dev/null +++ b/web/src/views/UserMemoryDetail/pages/ForgetDetail.tsx @@ -0,0 +1,158 @@ +import { type FC, useEffect, useState, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { useParams } from 'react-router-dom' +import { Row, Col, Progress } from 'antd' +import RbCard from '@/components/RbCard/Card' +import { + getForgetStats, +} from '@/api/memory' +import type { ForgetData } from '../types' +import ActivationMetricsPieCard from '../components/ActivationMetricsPieCard' +import RecentTrendsLineCard from '../components/RecentTrendsLineCard' +import Table from '@/components/Table' +import { formatDateTime } from '@/utils/format' +import StatusTag from '@/components/StatusTag' + +const statusTagColors: Record = { + statement: 'success', + entity: 'purple', + summary: 'default', + chunk: 'warning', +} + +const ForgetDetail: FC = () => { + const { t } = useTranslation() + const { id } = useParams() + const [loading, setLoading] = useState(false) + const [data, setData] = useState({} as ForgetData) + + useEffect(() => { + if (!id) return + getData() + }, [id]) + + const getData = () => { + if (!id) return + setLoading(true) + getForgetStats(id).then((res) => { + const response = res as ForgetData + setData(response) + setLoading(false) + }) + .finally(() => { + setLoading(false) + }) + } + const chartData = useMemo(() => { + const { activation_metrics } = data + if (!activation_metrics) return [] + + let health_nodes = (activation_metrics.total_nodes || 0) - (activation_metrics.low_activation_nodes || 0) - (activation_metrics.nodes_without_activation || 0) + + return [ + { name: t('forgetDetail.health_nodes'), value: health_nodes }, + { name: t('forgetDetail.nodes_without_activation'), value: 
activation_metrics.nodes_without_activation || 0 }, + { name: t('forgetDetail.low_activation_nodes'), value: activation_metrics.low_activation_nodes || 0 }, + ] + + }, [data.activation_metrics, t]) + + const seriesList = useMemo(() => { + const { recent_trends = [] } = data + if (!recent_trends || recent_trends.length === 0) return { chartData: [], seriesList: [] } + + return { + chartData: recent_trends, + seriesList: ['merged_count', 'average_activation'] + } + }, [data.recent_trends]) + + return ( +
+
{t('forgetDetail.title')}
+
{t('forgetDetail.overviewTitle')}
+ + + +
{t('forgetDetail.totalMemory')}
+
{data?.activation_metrics?.total_nodes ?? 0}
+
+ {['statement_count', 'entity_count', 'summary_count', 'chunk_count'].map((key, index) => ( +
+
{data?.node_distribution?.[key as keyof typeof data.node_distribution] ?? 0}
+
{t(`forgetDetail.${key}`)}
+
+ ))} +
+
+ + + +
{t('forgetDetail.MemoryHealth')}
+
{data?.activation_metrics?.average_activation_value ?? 0}
+ +
+
{t('forgetDetail.healthStatus')}
+
{data?.activation_metrics?.average_activation_value > data.activation_metrics?.forgetting_threshold ? t('forgetDetail.healthy') : t('forgetDetail.unhealthy')}
+
+ {t('forgetDetail.average')}
+ {t('forgetDetail.threshold')}{data.activation_metrics?.forgetting_threshold ?? 0} +
+
+
+ + + +
{t('forgetDetail.riskOfForgetting')}
+
{data.activation_metrics?.low_activation_nodes ?? 0}
+
{t('forgetDetail.low_nodes')}
+
+ +
+ +
{t('forgetDetail.memoryHealthVisualization')}
+ + + + + + + + +
{t('forgetDetail.pending_nodes')}
+
{content_summary}
+ }, + { + title: t('forgetDetail.node_type'), + dataIndex: 'node_type', + key: 'node_type', + render: (node_type: string) => { + return } + }, + { + title: t('forgetDetail.last_access_time'), + dataIndex: 'last_access_time', + key: 'last_access_time', + render: (last_access_time) => formatDateTime(last_access_time, 'YYYY-MM-DD HH:mm') + }, + { + title: t('forgetDetail.activation_value'), + dataIndex: 'activation_value', + key: 'activation_value', + }, + ]} + pagination={false} + /> + + ) +} +export default ForgetDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx b/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx new file mode 100644 index 00000000..ef23463a --- /dev/null +++ b/web/src/views/UserMemoryDetail/pages/ImplicitDetail.tsx @@ -0,0 +1,34 @@ +import { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import { Row, Col } from 'antd' + +import Preferences from '../components/Preferences' +import Portrait from '../components/Portrait' +import InterestAreas from '../components/InterestAreas' +import Habits from '../components/Habits' + +const ImplicitDetail: FC = () => { + const { t } = useTranslation() + + return ( +
+
{t('implicitDetail.title')}
+ + + +
{t('implicitDetail.portraitTitle')}
+
{t('implicitDetail.portraitSubTitle')}
+ +
+ + + + + + + + + + ) +} +export default ImplicitDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/PerceptualDetail.tsx b/web/src/views/UserMemoryDetail/pages/PerceptualDetail.tsx new file mode 100644 index 00000000..7e2d5353 --- /dev/null +++ b/web/src/views/UserMemoryDetail/pages/PerceptualDetail.tsx @@ -0,0 +1,32 @@ +import { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import { Row, Col } from 'antd' + +import PerceptualLastInfo from '../components/PerceptualLastInfo' +import Timeline from '../components/Timeline' + +const PerceptualDetail: FC = () => { + const { t } = useTranslation() + + return ( +
+
{t('perceptualDetail.lastInfo')}
+ + +
+ + + + + + + + + + +
{t('perceptualDetail.timeLine')}
+ + + ) +} +export default PerceptualDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx b/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx new file mode 100644 index 00000000..6cc8eafc --- /dev/null +++ b/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx @@ -0,0 +1,114 @@ +import { type FC, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useParams } from 'react-router-dom' +import { Space, Skeleton } from 'antd' +import { + getShortTerm, +} from '@/api/memory' +import Empty from '@/components/Empty' + +interface ShortTermItem { + retrieval: Array<{ query: string; retrieval: string[]; }>; + message: string; + answer: string; +} +interface LongTermItem { + query: string; + retrieval: string; +} +interface ShortData { + short_term: ShortTermItem[]; + long_term: LongTermItem[]; + entity: number; + retrieval_number: number; + long_term_number: number; +} +const ShortTermDetail: FC = () => { + const { t } = useTranslation() + const { id } = useParams() + const [loading, setLoading] = useState(false) + const [data, setData] = useState({} as ShortData) + + useEffect(() => { + if (!id) return + getData() + }, [id]) + + const getData = () => { + if (!id) return + setLoading(true) + getShortTerm(id).then((res) => { + const response = res as ShortData + setData(response) + setLoading(false) + }) + .finally(() => { + setLoading(false) + }) + } + + return ( +
+
+
{t('shortTermDetail.title')}
+ +
+ {(['retrieval_number', 'entity', 'long_term_number'] as const).map(key => ( +
+
{(data as any)[key] ?? 0}
+ {t(`shortTermDetail.${key}`)} +
+ ))} +
+
+ + +
{t('shortTermDetail.shortTermTitle')}
+
{t('shortTermDetail.shortTermSubTitle')}
+ + {loading + ? + : !data.short_term || data.short_term.length === 0 + ? + :data.short_term?.map((vo, voIdx) => ( +
+
{vo.message}
+ + {vo.retrieval.map((item, index) => ( +
+
{t('shortTermDetail.query')}: {item.query}
+
{t('shortTermDetail.answer')}:
+ {item.retrieval.length > 0 ? item.retrieval.map((retrieval, retrievalIdx) => ( +
- {retrieval}
+ )) :
{t('shortTermDetail.noAnswer')}
} +
+ ))} +
+
{t('shortTermDetail.answer')}
+
{vo.answer}
+
+
+
+ )) + } +
+ +
{t('shortTermDetail.longTermTitle')}
+
{t('shortTermDetail.shortTermSubTitle')}
+ + {loading + ? + : !data.long_term || data.long_term.length === 0 + ? + : data.long_term?.map((vo, voIdx) => ( +
+
{vo.query}
+
{vo.retrieval}
+
+ )) + } +
+
+ ) +} +export default ShortTermDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx b/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx index 744c244d..e6ddfd20 100644 --- a/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/StatementDetail.tsx @@ -1,53 +1,26 @@ -import { type FC, useEffect, useState } from 'react' -import { useParams } from 'react-router-dom' +import { type FC } from 'react' import { Row, Col, Space } from 'antd'; import WordCloud from '../components/WordCloud' import EmotionTags from '../components/EmotionTags' import Health from '../components/Health' import Suggestions from '../components/Suggestions' -import PageHeader from '../components/PageHeader' -import { - getEndUserProfile, -} from '@/api/memory' const StatementDetail: FC = () => { - const { id } = useParams() - const [name, setName] = useState('') - useEffect(() => { - if (!id) return - getData() - }, [id]) - - const getData = () => { - if (!id) return - getEndUserProfile(id).then((res) => { - const response = res as { other_name: string; id: string; } - setName(response.other_name ?? response.id) - }) - } return ( -
- -
- -
- - - - - - - - - - - - + + + + + + + + + + + + ) } diff --git a/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx b/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx new file mode 100644 index 00000000..3093d1ae --- /dev/null +++ b/web/src/views/UserMemoryDetail/pages/WorkingDetail.tsx @@ -0,0 +1,209 @@ +import { type FC, useEffect, useState, useMemo } from 'react' +import clsx from 'clsx' +import { useTranslation } from 'react-i18next' +import { useParams } from 'react-router-dom' +import { Row, Col, Select, Form, Space, Skeleton, Input, Button, Divider } from 'antd' +import RbCard from '@/components/RbCard/Card' +import { + getConversations, + getConversationMessages, + getConversationDetail, +} from '@/api/memory' +import { formatDateTime } from '@/utils/format' +import Tag from '@/components/Tag' +import RbAlert from '@/components/RbAlert' +import Empty from '@/components/Empty' +import ChatContent from '@/components/Chat/ChatContent' +import type { ChatItem } from '@/components/Chat/types' +import PageLoading from '@/components/Empty/PageLoading' + +interface Conversation { + title: string; + id: string; +} +interface Detail { + theme: string; + theme_intro: string; + summary: string; + question: string[]; + takeaways: string[]; + info_score: number; +} + +const WorkingDetail: FC = () => { + const { t } = useTranslation() + const { id } = useParams() + const [form] = Form.useForm() + const [loading, setLoading] = useState(false) + const [data, setData] = useState([]) + const [messagesLoading, setMessagesLoading] = useState(false) + const [messages, setMessages] = useState([]) + const [detailLoading, setDetailLoading] = useState(false) + const [detail, setDetail] = useState(null) + const [selected, setSelected] = useState(null) + + useEffect(() => { + if (!id) return + getData() + }, [id]) + + const getData = () => { + if (!id) return + setLoading(true) + setSelected(null) + setDetail(null) + setData([]) + getConversations(id).then((res) => { + const 
response = res as Conversation[] + setData(response) + setSelected(response[0] || null) + }) + .finally(() => { + setLoading(false) + }) + } + + useEffect(() => { + if (!id || !selected || !selected.id) return + getDetail(selected.id) + }, [id, selected]) + + const getDetail = (conversationId: string) => { + if (!id || !conversationId) return + + setDetail(null) + setMessages([]) + setDetailLoading(true) + setMessagesLoading(true) + getConversationMessages(id, conversationId) + .then(res => { + setMessages(res as ChatItem[]) + }) + .finally(() => { + setMessagesLoading(false) + }) + getConversationDetail(id, conversationId) + .then(res => { + setDetail(res as Detail) + }) + .finally(() => { + setDetailLoading(false) + }) + } + const timeRange = useMemo(() => { + const times = messages.filter(m => m.created_at).map(m => Number(m.created_at)) + if (times.length === 0) return '' + const minTime = Math.min(...times) + const maxTime = Math.max(...times) + return `${formatDateTime(minTime, 'YYYY.MM')} - ${formatDateTime(maxTime, 'YYYY.MM')}` + }, [messages]) + + return ( +
+ {loading + ? + : data.length === 0 + ? + :( + +
+
+ {data.map(item => ( +
+
getDetail(item.id)} + > + {item.title} +
+
+ ))} +
+ + {selected && <> + +
{selected.title}
+
{timeRange}
+ + + + getDetail(selected.id)}>{t('workingDetail.refresh')}} + className="rb:mt-4!" + headerClassName='rb:bg-[#F6F8FC]! rb:border-b! rb:border-b-[#DFE4ED]! rb:min-h-11!' + headerType="borderless" + bodyClassName="rb:h-[calc(100vh-210px)]" + > + {messagesLoading + ? + : messages.length === 0 + ? + : ( + formatDateTime(item.created_at)} + /> + ) + } + + + + + {detailLoading + ? + : detail + ? <> + <> +
{t('workingDetail.successfulTitle')}
+ + {detail.takeaways.length > 0 + ? ( +
    + {detail.takeaways.map(vo =>
  • {vo}
  • )} +
+ ) + : + } + + + <> + +
{t('workingDetail.question')}
+ + {detail.question.length > 0 + ? ( +
    + {detail.question.map(vo =>
  • {vo}
  • )} +
+ ) + : + } + + + <> + +
{t('workingDetail.summary')}
+ {detail.summary + ? {detail.summary} + : + } + + + : + } +
+ + + + } + + ) + } + + ) +} +export default WorkingDetail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/pages/index.tsx b/web/src/views/UserMemoryDetail/pages/index.tsx new file mode 100644 index 00000000..e734fd44 --- /dev/null +++ b/web/src/views/UserMemoryDetail/pages/index.tsx @@ -0,0 +1,74 @@ +import { type FC, useEffect, useState, useMemo } from 'react' +import { useParams, useNavigate } from 'react-router-dom' +import { useTranslation } from 'react-i18next' +import { Dropdown } from 'antd' + +import PageHeader from '../components/PageHeader' +import StatementDetail from './StatementDetail' +import ForgetDetail from './ForgetDetail' +import ImplicitDetail from './ImplicitDetail' +import ShortTermDetail from './ShortTermDetail' +import PerceptualDetail from './PerceptualDetail' +import EpisodicDetail from './EpisodicDetail' +import ExplicitDetail from './ExplicitDetail' +import WorkingDetail from './WorkingDetail' +import { + getEndUserProfile, +} from '@/api/memory' + +const Detail: FC = () => { + const { t } = useTranslation() + const { id, type } = useParams() + const navigate = useNavigate() + const [name, setName] = useState('') + useEffect(() => { + if (!id) return + getData() + }, [id]) + + const getData = () => { + if (!id) return + getEndUserProfile(id).then((res) => { + const response = res as { other_name: string; id: string; } + setName(response.other_name || response.id) + }) + } + const items = useMemo(() => { + return ['PERCEPTUAL_MEMORY', 'WORKING_MEMORY', 'EMOTIONAL_MEMORY', 'SHORT_TERM_MEMORY', 'IMPLICIT_MEMORY', 'EPISODIC_MEMORY', 'EXPLICIT_MEMORY', 'FORGETTING_MANAGEMENT'] + .map(key => ({ key, label: t(`userMemory.${key}`) })) + }, [t]) + const onClick = ({ key }: { key: string }) => { + navigate(`/user-memory/detail/${id}/${key}`, { replace: true }) + } + + return ( +
+ +
+ - {type ? t(`userMemory.${type}`) : ''} +
+
+ + } + /> +
+ {type === 'EMOTIONAL_MEMORY' && } + {type === 'FORGETTING_MANAGEMENT' && } + {type === 'IMPLICIT_MEMORY' && } + {type === 'SHORT_TERM_MEMORY' && } + {type === 'PERCEPTUAL_MEMORY' && } {/** TODO */} + {type === 'EPISODIC_MEMORY' && } + {type === 'WORKING_MEMORY' && } {/** TODO */} + {type === 'EXPLICIT_MEMORY' && } {/** TODO */} +
+
+ ) +} + +export default Detail \ No newline at end of file diff --git a/web/src/views/UserMemoryDetail/types.ts b/web/src/views/UserMemoryDetail/types.ts index 8fd050a9..afed31fd 100644 --- a/web/src/views/UserMemoryDetail/types.ts +++ b/web/src/views/UserMemoryDetail/types.ts @@ -44,6 +44,7 @@ export interface Data { export interface BaseProperties { content: string; created_at: number; + associative_memory: number; } export interface StatementNodeProperties { temporal_info: string; @@ -51,12 +52,21 @@ export interface StatementNodeProperties { statement: string; valid_at: string; created_at: number; + emotion_keywords: string[]; + emotion_type: string; + emotion_subject: string; + importance_score: number; + associative_memory: number; } export interface ExtractedEntityNodeProperties { description: string; name: string; entity_type: string; created_at: number; + aliases: string; + connect_strngth: string; + importance_score: number; + associative_memory: number; } export interface MemorySummaryNode { id: string; @@ -72,7 +82,7 @@ export interface MemorySummaryNode { created_at: number; } caption: string; - + associative_memory: number; } export interface Node { @@ -140,4 +150,41 @@ export interface AboutMeRef { } export interface EndUserProfileRef { data: EndUser | null +} + + +export interface ForgetData { + activation_metrics: { + total_nodes: number; + nodes_with_activation: number; + nodes_without_activation: number; + average_activation_value: number; + low_activation_nodes: number; + timestamp: number; + forgetting_threshold: number; + }, + node_distribution: { + statement_count: number; + entity_count: number; + summary_count: number; + chunk_count: number; + }, + recent_trends: { + date: string; + merged_count: number; + average_activation: number; + total_nodes: number; + execution_time: number; + }[], + pending_nodes: { + node_id: string; + node_type: string; + content_summary: string; + activation_value: number; + last_access_time: number; + }[], + 
timestamp: number; +} +export interface GraphDetailRef { + handleOpen: (vo: Node) => void } \ No newline at end of file diff --git a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx index 571f1e4e..aaaa2ab5 100644 --- a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx +++ b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx @@ -1,5 +1,5 @@ import { forwardRef, useImperativeHandle, useState } from 'react'; -import { Form, Input, Select, Checkbox, InputNumber } from 'antd'; +import { Form, Input, Select, InputNumber } from 'antd'; import { useTranslation } from 'react-i18next'; import type { ChatVariableModalRef } from './types' @@ -26,7 +26,7 @@ const ChatVariableModal = forwardRef(); const [loading, setLoading] = useState(false) const [editIndex, setEditIndex] = useState(undefined) - const typeValue = Form.useWatch('type', form); + const type = Form.useWatch('type', form); // 封装取消方法,添加关闭弹窗逻辑 const handleClose = () => { @@ -39,7 +39,8 @@ const ChatVariableModal = forwardRef { setVisible(true); if (variable) { - form.setFieldsValue(variable) + const { default: _, ...rest } = variable + form.setFieldsValue({ ...rest }) setEditIndex(index) } else { form.resetFields(); @@ -49,7 +50,7 @@ const ChatVariableModal = forwardRef { form.validateFields().then((values) => { - refresh({ ...values }, editIndex) + refresh({ ...values, default: values.defaultValue }, editIndex) handleClose() }) } @@ -90,52 +91,36 @@ const ChatVariableModal = forwardRef - ); - } - return ; - }} - - - + {type === 'number' + ? + : type === 'boolean' + ? 
+ } + - - - {t('workflow.config.parameter-extractor.required')} - ); diff --git a/web/src/views/Workflow/components/AddChatVariable/index.tsx b/web/src/views/Workflow/components/AddChatVariable/index.tsx index f765b5eb..7ebce7df 100644 --- a/web/src/views/Workflow/components/AddChatVariable/index.tsx +++ b/web/src/views/Workflow/components/AddChatVariable/index.tsx @@ -40,7 +40,7 @@ const AddChatVariable = forwardRef(({ } const handleSave = (value: ChatVariable, index?: number) => { const list = [...variables] - if (index && index > -1) { + if (typeof index === 'number' && index > -1) { list[index] = value } else { list.push(value) @@ -75,17 +75,15 @@ const AddChatVariable = forwardRef(({ dataSource={variables} renderItem={(item, index) => ( -
+
{item.name} ({t(`workflow.config.parameter-extractor.${item.type}`)})
- {item.required ? t('workflow.config.parameter-extractor.required') : ''} -
{item.description}
- +
handleEdit(index)} diff --git a/web/src/views/Workflow/components/AddChatVariable/types.ts b/web/src/views/Workflow/components/AddChatVariable/types.ts index ab00ae69..5d9aa7b0 100644 --- a/web/src/views/Workflow/components/AddChatVariable/types.ts +++ b/web/src/views/Workflow/components/AddChatVariable/types.ts @@ -11,8 +11,8 @@ export interface VariableFormData { name: string; type: ChatVariable['type']; description?: string; - required?: boolean; - defaultValue?: any; + defaultValue?: string; + default?: string; } export interface ChatVariableModalRef { diff --git a/web/src/views/Workflow/components/Editor/index.tsx b/web/src/views/Workflow/components/Editor/index.tsx index 2d12f3f0..c487f2f4 100644 --- a/web/src/views/Workflow/components/Editor/index.tsx +++ b/web/src/views/Workflow/components/Editor/index.tsx @@ -1,4 +1,4 @@ -import { type FC, useState } from 'react'; +import { type FC, useState, useEffect } from 'react'; import { LexicalComposer } from '@lexical/react/LexicalComposer'; import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'; import { ContentEditable } from '@lexical/react/LexicalContentEditable'; @@ -23,6 +23,7 @@ interface LexicalEditorProps { options: Suggestion[]; variant?: 'outlined' | 'borderless'; height?: number; + enableJinja2?: boolean; } const theme = { @@ -33,6 +34,15 @@ const theme = { }, }; +const jinja2Theme = { + ...theme, + code: 'jinja2-expression', + text: { + ...theme.text, + code: 'jinja2-inline', + }, +}; + const Editor: FC =({ placeholder = "请输入内容...", value = "", @@ -40,19 +50,62 @@ const Editor: FC =({ options, variant = 'borderless', height = 60, + enableJinja2 = false, }) => { + const [_count, setCount] = useState(0); + + useEffect(() => { + if (enableJinja2) { + const styleId = 'jinja2-styles'; + let existingStyle = document.getElementById(styleId); + + if (!existingStyle) { + const style = document.createElement('style'); + style.id = styleId; + style.textContent = ` + .jinja2-expression { + 
background-color: #f6f8fa !important; + border: 1px solid #d1d9e0 !important; + border-radius: 3px !important; + padding: 2px 4px !important; + font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important; + font-size: 13px !important; + color: #0969da !important; + } + .jinja2-inline { + background-color: #f6f8fa !important; + padding: 1px 3px !important; + border-radius: 2px !important; + font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important; + font-size: 13px !important; + color: #0969da !important; + } + .editor-paragraph { + margin: 0; + } + .editor-paragraph:has-text('{') .editor-text, + .editor-paragraph:has-text('[') .editor-text { + font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important; + } + `; + document.head.appendChild(style); + } + } + }, [enableJinja2]); const initialConfig = { namespace: 'AutocompleteEditor', - theme, - nodes: [ + theme: enableJinja2 ? jinja2Theme : theme, + nodes: enableJinja2 ? [ + // 当启用jinja2时,不使用VariableNode,使用普通文本 + ] : [ // HeadingNode, // QuoteNode, // ListItemNode, // ListNode, // LinkNode, // CodeNode, - VariableNode + VariableNode, ], onError: (error: Error) => { console.error(error); @@ -96,9 +149,9 @@ const Editor: FC =({ /> - + { setCount(count) }} onChange={onChange} /> - +
); diff --git a/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx index 5c5d3956..34ef3196 100644 --- a/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx @@ -17,7 +17,7 @@ export interface Suggestion { disabled?: boolean; // 标记是否禁用 } -const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => { +const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> = ({ options, enableJinja2 = false }) => { const [editor] = useLexicalComposerContext(); const [showSuggestions, setShowSuggestions] = useState(false); const [selectedIndex, setSelectedIndex] = useState(0); @@ -82,7 +82,32 @@ const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => { }, [editor]); const insertMention = (suggestion: Suggestion) => { - editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion }); + if (enableJinja2) { + // 在jinja2模式下,插入{{variable}}格式的文本 + editor.update(() => { + const selection = $getSelection(); + if ($isRangeSelection(selection)) { + const anchorNode = selection.anchor.getNode(); + const anchorOffset = selection.anchor.offset; + const nodeText = anchorNode.getTextContent(); + + // 移除触发字符'/' + const textBefore = nodeText.substring(0, anchorOffset - 1); + const textAfter = nodeText.substring(anchorOffset); + const newText = textBefore + `{{${suggestion.value}}}` + textAfter; + + anchorNode.setTextContent(newText); + + // 设置光标位置到插入文本之后 + const newOffset = textBefore.length + `{{${suggestion.value}}}`.length; + selection.anchor.offset = newOffset; + selection.focus.offset = newOffset; + } + }); + } else { + // 普通模式下使用VariableNode + editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion }); + } setShowSuggestions(false); }; diff --git a/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx 
b/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx index 963f824b..ed07392d 100644 --- a/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx @@ -14,18 +14,23 @@ const CharacterCountPlugin = ({ setCount, onChange }: { setCount: (count: number let serializedContent = ''; // Traverse all nodes and serialize properly + const paragraphs: string[] = []; root.getChildren().forEach(child => { if ($isParagraphNode(child)) { + let paragraphContent = ''; child.getChildren().forEach(node => { if ($isVariableNode(node)) { - serializedContent += node.getTextContent(); + paragraphContent += node.getTextContent(); } else { - serializedContent += node.getTextContent(); + paragraphContent += node.getTextContent(); } }); + paragraphs.push(paragraphContent); } }); + serializedContent = paragraphs.join('\n'); + setCount(serializedContent.length); onChange?.(serializedContent); }); diff --git a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx index 4059b300..33e31199 100644 --- a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx @@ -8,14 +8,31 @@ import { type Suggestion } from '../plugin/AutocompletePlugin' interface InitialValuePluginProps { value: string; options?: Suggestion[]; + enableJinja2?: boolean; } -const InitialValuePlugin: React.FC = ({ value, options = [] }) => { +const InitialValuePlugin: React.FC = ({ value, options = [], enableJinja2 = false }) => { const [editor] = useLexicalComposerContext(); - const initializedRef = useRef(false); + const prevValueRef = useRef(''); + const isUserInputRef = useRef(false); useEffect(() => { - if (!initializedRef.current && value) { + // 监听编辑器变化,标记是否为用户输入 + const removeListener = editor.registerUpdateListener(({ editorState 
}) => { + editorState.read(() => { + const root = $getRoot(); + const textContent = root.getTextContent(); + if (textContent !== prevValueRef.current) { + isUserInputRef.current = true; + } + }); + }); + + return removeListener; + }, [editor]); + + useEffect(() => { + if (value !== prevValueRef.current && !isUserInputRef.current) { editor.update(() => { const root = $getRoot(); root.clear(); @@ -26,8 +43,13 @@ const InitialValuePlugin: React.FC = ({ value, options parts.forEach(part => { const match = part.match(/^\{\{([^.]+)\.([^}]+)\}\}$/); const contextMatch = part.match(/^\{\{context\}\}$/); + const conversationMatch = part.match(/^\{\{conv\.([^}]+)\}\}$/); + + if (enableJinja2) { + paragraph.append($createTextNode(part)); + return; + } - // 匹配{{context}}格式 if (contextMatch) { const contextSuggestion = options.find(s => s.isContext && s.label === 'context'); if (contextSuggestion) { @@ -38,7 +60,19 @@ const InitialValuePlugin: React.FC = ({ value, options return } - // 匹配普通变量{{nodeId.label}}格式 + if (conversationMatch) { + const [_, variableName] = conversationMatch; + const conversationSuggestion = options.find(s => + s.group === 'CONVERSATION' && s.label === variableName + ); + if (conversationSuggestion) { + paragraph.append($createVariableNode(conversationSuggestion)); + } else { + paragraph.append($createTextNode(part)); + } + return + } + if (match) { const [_, nodeId, label] = match; @@ -60,11 +94,12 @@ const InitialValuePlugin: React.FC = ({ value, options }); root.append(paragraph); - }); - - initializedRef.current = true; + }, { discrete: true }); } - }, [options]); + + prevValueRef.current = value; + isUserInputRef.current = false; + }, [value, options, editor, enableJinja2]); return null; }; diff --git a/web/src/views/Workflow/components/Editor/plugin/JsonHighlightPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/JsonHighlightPlugin.tsx new file mode 100644 index 00000000..93180f79 --- /dev/null +++ 
b/web/src/views/Workflow/components/Editor/plugin/JsonHighlightPlugin.tsx @@ -0,0 +1,109 @@ +import { useEffect } from 'react'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; +import { $getRoot, $getSelection, $isRangeSelection, TextNode, $createTextNode } from 'lexical'; + +const JsonHighlightPlugin = () => { + const [editor] = useLexicalComposerContext(); + + useEffect(() => { + return editor.registerNodeTransform(TextNode, (textNode: TextNode) => { + const text = textNode.getTextContent(); + + // Check if text contains JSON-like patterns + if (containsJsonPatterns(text)) { + const parent = textNode.getParent(); + if (!parent) return; + + // Split text into tokens and create new nodes with appropriate classes + const tokens = tokenizeJson(text); + const newNodes = tokens.map(token => { + const newNode = $createTextNode(token.text); + + // Set format based on token type + switch (token.type) { + case 'string': + newNode.setFormat('code'); + newNode.setStyle('color: #032f62'); + break; + case 'number': + newNode.setFormat('code'); + newNode.setStyle('color: #005cc5'); + break; + case 'boolean': + newNode.setFormat('code'); + newNode.setStyle('color: #d73a49'); + break; + case 'null': + newNode.setFormat('code'); + newNode.setStyle('color: #6f42c1'); + break; + case 'key': + newNode.setFormat('code'); + newNode.setStyle('color: #22863a; font-weight: bold'); + break; + case 'punctuation': + newNode.setFormat('code'); + newNode.setStyle('color: #24292e'); + break; + } + + return newNode; + }); + + // Replace the original text node with the new highlighted nodes + if (newNodes.length > 1) { + textNode.replace(newNodes[0]); + for (let i = 1; i < newNodes.length; i++) { + newNodes[i - 1].insertAfter(newNodes[i]); + } + } + } + }); + }, [editor]); + + return null; +}; + +function containsJsonPatterns(text: string): boolean { + // Check for JSON-like patterns + return /[{}\[\]:,]/.test(text) || + /"[^"]*"/.test(text) || + 
/\b\d+(\.\d+)?\b/.test(text) || + /\b(true|false|null)\b/.test(text); +} + +function tokenizeJson(text: string): Array<{text: string, type: string}> { + const tokens: Array<{text: string, type: string}> = []; + const regex = /("[^"]*")|([{}\[\]:,])|(\b\d+(?:\.\d+)?\b)|(\b(?:true|false|null)\b)|(\s+)|([^\s{}\[\]:,"]+)/g; + + let match; + while ((match = regex.exec(text)) !== null) { + const [fullMatch, string, punctuation, number, boolean, whitespace, other] = match; + + if (string) { + // Check if it's a key (followed by colon) + const afterMatch = text.slice(match.index + fullMatch.length).trim(); + if (afterMatch.startsWith(':')) { + tokens.push({ text: fullMatch, type: 'key' }); + } else { + tokens.push({ text: fullMatch, type: 'string' }); + } + } else if (punctuation) { + tokens.push({ text: fullMatch, type: 'punctuation' }); + } else if (number) { + tokens.push({ text: fullMatch, type: 'number' }); + } else if (boolean) { + if (fullMatch === 'null') { + tokens.push({ text: fullMatch, type: 'null' }); + } else { + tokens.push({ text: fullMatch, type: 'boolean' }); + } + } else if (whitespace || other) { + tokens.push({ text: fullMatch, type: 'text' }); + } + } + + return tokens; +} + +export default JsonHighlightPlugin; \ No newline at end of file diff --git a/web/src/views/Workflow/components/Nodes/AddNode.tsx b/web/src/views/Workflow/components/Nodes/AddNode.tsx index a2f6d930..973a503c 100644 --- a/web/src/views/Workflow/components/Nodes/AddNode.tsx +++ b/web/src/views/Workflow/components/Nodes/AddNode.tsx @@ -13,13 +13,15 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { const handleNodeSelect = (selectedNodeType: any) => { const parentBBox = node.getBBox(); const cycleId = data.cycle; - + + const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const newNode = graph.addNode({ ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default), x: parentBBox.x, y: 
parentBBox.y, + id, data: { - id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, + id, type: selectedNodeType.type, icon: selectedNodeType.icon, name: t(`workflow.${selectedNodeType.type}`), diff --git a/web/src/views/Workflow/components/Nodes/LoopNode.tsx b/web/src/views/Workflow/components/Nodes/LoopNode.tsx index b0b8d4ce..37feb2dc 100644 --- a/web/src/views/Workflow/components/Nodes/LoopNode.tsx +++ b/web/src/views/Workflow/components/Nodes/LoopNode.tsx @@ -75,12 +75,15 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { const parentBBox = node.getBBox(); const centerX = parentBBox.x + 24; // 默认节点宽度的一半 const centerY = parentBBox.y + 50; // 默认节点高度的一半 - + + const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const cycleStartNode = graph.addNode({ ...graphNodeLibrary.cycleStart, x: centerX, y: centerY, + id: cycleStartNodeId, data: { + id: cycleStartNodeId, type: 'cycle-start', parentId: node.id, isDefault: true, // 标记为默认节点,不可删除 diff --git a/web/src/views/Workflow/components/PortClickHandler.tsx b/web/src/views/Workflow/components/PortClickHandler.tsx index 0be6fba1..9a644438 100644 --- a/web/src/views/Workflow/components/PortClickHandler.tsx +++ b/web/src/views/Workflow/components/PortClickHandler.tsx @@ -43,12 +43,14 @@ const PortClickHandler: React.FC = ({ graph }) => { const newY = sourceBBox.y; // 创建新节点 + const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const newNode = graph.addNode({ ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default), x: newX, y: newY, + id, data: { - id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, + id, type: selectedNodeType.type, icon: selectedNodeType.icon, name: t(`workflow.${selectedNodeType.type}`), diff --git a/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx 
b/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx index 34c133c7..97f28668 100644 --- a/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx +++ b/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx @@ -1,6 +1,6 @@ import { type FC } from 'react' import { useTranslation } from 'react-i18next'; -import { Form, Input, Button, Row, Col, Select } from 'antd' +import { Form, Input, Row, Col, Select, InputNumber, Radio } from 'antd' import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import VariableSelect from '../VariableSelect' @@ -11,6 +11,23 @@ interface AssignmentListProps { options: Suggestion[]; } +const operationsObj = { + number: [ + { value: 'cover', label: 'workflow.config.assigner.cover' }, + { value: 'clear', label: 'workflow.config.assigner.clear' }, + { value: 'assign', label: 'workflow.config.assigner.assign' }, + { value: 'add', label: 'workflow.config.assigner.add' }, + { value: 'subtract', label: 'workflow.config.assigner.subtract' }, + { value: 'multiply', label: 'workflow.config.assigner.multiply' }, + { value: 'divide', label: 'workflow.config.assigner.divide' }, + ], + default: [ + { value: 'cover', label: 'workflow.config.assigner.cover' }, + { value: 'clear', label: 'workflow.config.assigner.clear' }, + { value: 'assign', label: 'workflow.config.assigner.assign' }, + ], +} + const AssignmentList: FC = ({ parentName, options = [], @@ -27,6 +44,11 @@ const AssignmentList: FC = ({ add({ operation: 'cover'})} />
{fields.map(({ key, name, ...restField }) => { + const variableSelector = form.getFieldValue([parentName, name, 'variable_selector']); + const selectedOption = options.find(option => `{{${option.value}}}` === variableSelector); + const dataType = selectedOption?.dataType; + const operationOptions = dataType === 'number' ? operationsObj.number : operationsObj.default; + return (
@@ -40,6 +62,10 @@ const AssignmentList: FC = ({ placeholder={t('common.pleaseSelect')} options={options} popupMatchSelectWidth={false} + onChange={() => { + form.setFieldValue([parentName, name, 'operation'], undefined); + form.setFieldValue([parentName, name, 'value'], undefined); + }} /> @@ -50,11 +76,11 @@ const AssignmentList: FC = ({ noStyle > ({ - value: key, - label: t(`workflow.config.if-else.${key}`) + options={operatorList.map(vo => ({ + ...vo, + label: t(String(vo?.label || '')) }))} size="small" popupMatchSelectWidth={false} + placeholder={t('common.pleaseSelect')} /> @@ -280,11 +323,52 @@ const CaseList: FC = ({ - {!hideRightField && ( - - - - )} + {!hideRightField && <> + {leftFieldType === 'number' + ? +
+ + ({ - value: key, - label: t(`workflow.config.if-else.${key}`) + options={operatorList.map(vo => ({ + ...vo, + label: t(String(vo?.label || '')) }))} size="small" popupMatchSelectWidth={false} @@ -104,14 +141,57 @@ const ConditionList: FC = ({ onClick={() => remove(field.name)} /> - - {!hideRightField && ( - - - - - - )} + + {!hideRightField && <> + {leftFieldType === 'number' + ? + + + (({ { - console.log('value record', value) - handleChange(record.key, 'type', value) - }} - /> - ), - }, - { - title: t('workflow.config.value'), - dataIndex: 'value', - width: '45%', - render: (text: string, record: TableRow) => { - if (record.type === 'file') { - - return ( - ( + + { - console.log('value record', value) - handleChange(record.key, 'value', value) - }} + filterBooleanType={filterBooleanType} + popupMatchSelectWidth={false} /> - ) - } - return ( - { - console.log('value record', value) - handleChange(record.key, 'value', value) - }} - /> + ) }, - }, - { - title: '', - width: '10%', - render: (_: any, record: TableRow, index: number) => ( -
}} - scroll={{ x: 'max-content' }} - /> - {!title && - - } + + {(fields, { add, remove }) => { + const AddButton = ({ block = false }: { block?: boolean }) => ( + + ); + + return ( + <> + {title && ( +
+
{title}
+ +
+ )} + + + bordered + dataSource={fields.map((field) => ({ + key: String(field.key), + name: undefined, + value: undefined, + type: undefined + }))} + columns={getColumns(remove)} + pagination={false} + size="small" + locale={{ emptyText: }} + scroll={{ x: 'max-content' }} + /> + + {!title && } + + ); + }} +
); }; diff --git a/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx b/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx index 3b855c48..80584220 100644 --- a/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx +++ b/web/src/views/Workflow/components/Properties/HttpRequest/index.tsx @@ -1,16 +1,18 @@ -import { type FC, useEffect, useRef } from "react"; +import { type FC, useRef } from "react"; import { useTranslation } from 'react-i18next' -import { Form, Row, Col, Select, Button, Divider, InputNumber, Switch, Input, Slider } from 'antd' +import { Form, Row, Col, Select, Button, Divider, InputNumber, Switch, Input } from 'antd' import Editor from '../../Editor' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import AuthConfigModal from './AuthConfigModal' import type { AuthConfigModalRef, HttpRequestConfigForm } from './types' import VariableSelect from "../VariableSelect"; import MessageEditor from '../MessageEditor' -import EditableTable, { type TableRow } from './EditableTable' +import EditableTable from './EditableTable' -const HttpRequest: FC<{ options: Suggestion[]; }> = ({ +const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: any; }> = ({ options, + selectedNode, + graphRef }) => { const { t } = useTranslation() const form = Form.useFormInstance(); @@ -22,29 +24,45 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({ } const handleRefresh = (auth: HttpRequestConfigForm['auth']) => { console.log('handleRefresh', auth) - form.setFieldsValue({ auth: {...auth} }) + form.setFieldsValue({ auth }) } - const handleChangeBodyContentType = (contentType: string) => { - const currentValues = form.getFieldsValue() + const handleChangeBodyContentType = () => { + form.setFieldValue(['body', 'data'], undefined) + } + + const handleChangeErrorHandleMethod = (method: string) => { form.setFieldsValue({ - body: { - ...currentValues?.body, - content_type: contentType, - data: 
undefined + error_handle: { + method, + body: undefined, + status_code: undefined, + headers: undefined } }) - } - - const updateObjectList = (data: TableRow[], key: string) => { - let obj: Record = {} - if (data.length) { - data.forEach(vo => { - obj[vo.name] = vo.value - }) + + // 更新节点连接桩 + console.log('handleChangeErrorHandleMethod', selectedNode, graphRef?.current) + if (selectedNode && graphRef?.current) { + const existingPorts = selectedNode.getPorts(); + const errorPort = existingPorts.find((port: any) => port.id === 'ERROR'); + + if (method === 'branch' && !errorPort) { + // 添加异常节点连接桩 + selectedNode.addPort({ + id: 'ERROR', + group: 'right', + attrs: { text: { text: t('workflow.config.http-request.errorBranch'), fontSize: 12, fill: '#5B6167' }} + }); + } else if (method !== 'branch' && errorPort) { + // 移除异常节点连接桩和相关连线 + const edges = graphRef.current.getEdges().filter((edge: any) => + edge.getSourceCellId() === selectedNode.id && edge.getSourcePortId() === 'ERROR' + ); + edges.forEach((edge: any) => graphRef.current.removeCell(edge)); + selectedNode.removePort('ERROR'); + } } - - form.setFieldValue(key, obj) } console.log('HttpRequest', values) @@ -81,17 +99,19 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({ updateObjectList(headers, 'headers')} + filterBooleanType={true} /> updateObjectList(params, 'params')} + filterBooleanType={true} /> @@ -113,15 +133,9 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({ {values?.body?.content_type === 'form-data' && { - form.setFieldsValue({ - body: { - ...form.getFieldValue('body'), - data - } - }) - }} + filterBooleanType={true} typeOptions={[ { label: 'text', value: 'text' }, { label: 'file', value: 'file' } @@ -132,19 +146,17 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({ {values?.body?.content_type === 'x-www-form-urlencoded' && { - const currentBody = form.getFieldValue('body') || {} - form.setFieldsValue({ - body: { ...currentBody, data } - }) - }} + filterBooleanType={true} /> } 
{values?.body?.content_type === 'json' && = ({ {values?.body?.content_type === 'raw' && = ({ } @@ -179,19 +194,31 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({ name={['timeouts', 'connect_timeout']} label={t('workflow.config.http-request.connect_timeout')} > - + form.setFieldValue(['timeouts', 'connect_timeout'], value)} + /> - + form.setFieldValue(['timeouts', 'read_timeout'], value)} + /> - + form.setFieldValue(['timeouts', 'write_timeout'], value)} + /> @@ -204,13 +231,21 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({ name={['retry', 'max_attempts']} label={t('workflow.config.http-request.max_attempts')} > - + form.setFieldValue(['retry', 'max_attempts'], value)} + /> {t('workflow.config.http-request.retry_interval')} (ms)} > - + form.setFieldValue(['retry', 'retry_interval'], value)} + /> } @@ -219,6 +254,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({ status_code number} > - + form.setFieldValue(['error_handle', 'status_code'], value)} + /> { - if (!value) return Promise.resolve(); - try { - JSON.parse(value); - return Promise.resolve(); - } catch { - return Promise.reject(new Error('Please enter valid JSON format')); - } - } - } - ]} + label={<>headers object} > diff --git a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx index aeafb67c..af89f7d2 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx @@ -128,29 +128,32 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi ( - -
-
- {item.name} - - {item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')} - -
{t('application.contains', {include_count: item.doc_num})}
+ renderItem={(item) => { + if (!item.id) return null + return ( + +
+
+ {item.name} + + {item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')} + +
{t('application.contains', {include_count: item.doc_num})}
+
+ +
handleEditKnowledge(item)} + >
+
handleDeleteKnowledge(item.id)} + >
+
- -
handleEditKnowledge(item)} - >
-
handleDeleteKnowledge(item.id)} - >
-
-
- - )} + + ) + }} /> } {/* 全局设置 */} diff --git a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx index e2c87ada..abf56b18 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx @@ -117,7 +117,12 @@ const KnowledgeConfigModal = forwardRef - + form.setFieldValue('top_k', value)} + /> {/* 语义相似度阈值 similarity_threshold */} {values?.retrieve_type === 'semantic' && ( diff --git a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeGlobalConfigModal.tsx b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeGlobalConfigModal.tsx index 773fb881..2f349487 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeGlobalConfigModal.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeGlobalConfigModal.tsx @@ -110,7 +110,12 @@ const KnowledgeGlobalConfigModal = forwardRef - + form.setFieldValue('reranker_top_k', value)} + /> } diff --git a/web/src/views/Workflow/components/Properties/MappingList/index.tsx b/web/src/views/Workflow/components/Properties/MappingList/index.tsx index b2066681..f35a422e 100644 --- a/web/src/views/Workflow/components/Properties/MappingList/index.tsx +++ b/web/src/views/Workflow/components/Properties/MappingList/index.tsx @@ -1,12 +1,15 @@ import React from 'react'; import { useTranslation } from 'react-i18next' import { MinusCircleOutlined } from '@ant-design/icons'; -import { Button, Form, Input, Space } from 'antd'; +import { Button, Form, Input, Space, Row, Col } from 'antd'; +import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' +import VariableSelect from '../VariableSelect' interface MappingListProps { name: string; + options: Suggestion[]; } -const MappingList: React.FC = ({ name }) => { +const MappingList: React.FC = ({ name, 
options }) => { const { t } = useTranslation() return ( <> @@ -14,23 +17,33 @@ const MappingList: React.FC = ({ name }) => { {(fields, { add, remove }) => ( <> {fields.map(({ key, name, ...restField }) => ( - - - - - - - - remove(name)} /> - + +
+ + + + + + + + + + + remove(name)} /> + + ))} + {title ?? t('workflow.answerDesc')} + + + + + + + ); } return ( -
- {isArray - ? - {(fields, { add, remove }) => ( - - {fields.map(({ key, name, ...restField }) => { - const currentRole = (values[parentName]?.[key].role || 'USER').toUpperCase() - - return ( - - -
- - {currentRole === 'SYSTEM' - ? - : - + ) : ( + } - {['number'].includes(values.type) && } + {['number'].includes(values.type) && ( + form.setFieldValue('default', value)} + /> + )} {['boolean'].includes(values.type) && = ({ selectedNode, graphRef, - config, + config: workflowConfig, }) => { const { t } = useTranslation() const { modal } = App.useApp() @@ -45,13 +45,134 @@ const Properties: FC = ({ const values = Form.useWatch([], form); const variableModalRef = useRef(null) const [editIndex, setEditIndex] = useState(null) + const prevMappingNamesRef = useRef([]) + const prevTemplateVarsRef = useRef([]) + const syncTimeoutRef = useRef(null) + const isSyncingRef = useRef(false) + const lastSyncSourceRef = useRef<'mapping' | 'template' | null>(null) useEffect(() => { if (selectedNode?.getData()?.id) { form.resetFields() + prevMappingNamesRef.current = [] + prevTemplateVarsRef.current = [] + lastSyncSourceRef.current = null } }, [selectedNode?.getData()?.id]) + // Sync template when mapping names change + useEffect(() => { + if (isSyncingRef.current || lastSyncSourceRef.current === 'mapping' || selectedNode?.data?.type !== 'jinja-render' || !values?.mapping || !values?.template) return + + const currentMappingNames = Array.isArray(values.mapping) ? 
values.mapping.map((item: any) => item.name).filter(Boolean) : [] + const prevNames = prevMappingNamesRef.current + + if (prevNames.length === 0) { + prevMappingNamesRef.current = currentMappingNames + return + } + + if (JSON.stringify(prevNames) === JSON.stringify(currentMappingNames)) return + + if (syncTimeoutRef.current) clearTimeout(syncTimeoutRef.current) + const activeElement = document.activeElement as HTMLElement + + syncTimeoutRef.current = setTimeout(() => { + let updatedTemplate = String(form.getFieldValue('template') || '') + + prevNames.forEach((oldName, index) => { + const newName = currentMappingNames[index] + if (newName && oldName !== newName) { + updatedTemplate = updatedTemplate.replace(new RegExp(`{{\\s*${oldName}\\s*}}`, 'g'), `{{${newName}}}`) + } + }) + + if (updatedTemplate !== form.getFieldValue('template')) { + isSyncingRef.current = true + lastSyncSourceRef.current = 'mapping' + const newTemplateVars = (updatedTemplate.match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, '')) + prevTemplateVarsRef.current = newTemplateVars + prevMappingNamesRef.current = currentMappingNames + form.setFieldValue('template', updatedTemplate) + + requestAnimationFrame(() => { + activeElement?.focus?.() + setTimeout(() => { + isSyncingRef.current = false + lastSyncSourceRef.current = null + }, 50) + }) + } else { + prevMappingNamesRef.current = currentMappingNames + } + }, 0) + }, [values?.mapping, selectedNode?.data?.type, form]) + + // Sync mapping when template variables change + useEffect(() => { + if (isSyncingRef.current || lastSyncSourceRef.current === 'template' || selectedNode?.data?.type !== 'jinja-render' || !values?.template || !values?.mapping) return + + const templateVars = (String(values.template).match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, '')) + if (JSON.stringify(prevTemplateVarsRef.current) === JSON.stringify(templateVars)) return + + const isTemplateEditor = 
document.activeElement?.closest('[data-editor-type="template"]') + if (!isTemplateEditor) { + prevTemplateVarsRef.current = templateVars + return + } + + const updatedMapping = Array.isArray(values.mapping) ? [...values.mapping] : [] + const existingNames = updatedMapping.map(item => item.name) + let updatedTemplate = String(values.template) + + if (prevTemplateVarsRef.current.length > 0) { + prevTemplateVarsRef.current.forEach((oldVar, index) => { + const newVar = templateVars[index] + if (newVar && oldVar !== newVar && updatedMapping[index]) { + updatedMapping[index] = { ...updatedMapping[index], name: newVar } + } + }) + } + + templateVars.forEach(varName => { + const existingMapping = updatedMapping.find(item => item.value === `{{${varName}}}`) + const regex = new RegExp(`{{\\s*${varName.replace(/\./g, '\\.')}\\s*}}`, 'g') + + if (existingMapping) { + updatedTemplate = updatedTemplate.replace(regex, `{{${existingMapping.name}}}`) + } else if (!existingNames.includes(varName)) { + const mappingName = varName.includes('.') ? 
varName.split('.').pop() || varName : varName + updatedMapping.push({ name: mappingName, value: `{{${varName}}}` }) + updatedTemplate = updatedTemplate.replace(regex, `{{${mappingName}}}`) + } + }) + + const seenNames = new Set() + const finalMapping = updatedMapping.filter(item => { + const isUsed = templateVars.some(v => item.name === v || item.value === `{{${v}}}`) + if (!isUsed || seenNames.has(item.name)) return false + seenNames.add(item.name) + return true + }) + + isSyncingRef.current = true + lastSyncSourceRef.current = 'template' + prevMappingNamesRef.current = finalMapping.map((item: any) => item.name).filter(Boolean) + prevTemplateVarsRef.current = templateVars + + if (JSON.stringify(finalMapping) !== JSON.stringify(values.mapping)) { + form.setFieldValue('mapping', finalMapping) + } + if (updatedTemplate !== String(values.template)) { + form.setFieldValue('template', updatedTemplate) + } + + setTimeout(() => { + isSyncingRef.current = false + lastSyncSourceRef.current = null + }, 50) + }, [values?.template, selectedNode?.data?.type, form]) + useEffect(() => { if (selectedNode && form) { const { type = 'default', name = '', config } = selectedNode.getData() || {} @@ -82,18 +203,9 @@ const Properties: FC = ({ useEffect(() => { if (values && selectedNode) { - const { id, knowledge_retrieval, group, group_names, ...rest } = values + const { id, knowledge_retrieval, group, group_variables, ...rest } = values const { knowledge_bases = [], ...restKnowledgeConfig } = (knowledge_retrieval as any) || {} - let groupNames: Record | string[] = {} - - if (group && group_names?.length) { - group_names.forEach(vo => { - (groupNames as Record)[vo.key] = vo.value - }) - } else if (!group) { - groupNames = group_names?.[0]?.value || [] - } let allRest = { ...rest, ...restKnowledgeConfig, @@ -105,9 +217,18 @@ const Properties: FC = ({ })) } + + Object.keys(values).forEach(key => { if (selectedNode.data?.config?.[key]) { - selectedNode.data.config[key].defaultValue = 
values[key] + // Create a deep copy to avoid reference sharing between nodes + if (!selectedNode.data.config[key]) { + selectedNode.data.config[key] = {}; + } + selectedNode.data.config[key] = { + ...selectedNode.data.config[key], + defaultValue: values[key] + }; } }) @@ -116,7 +237,7 @@ const Properties: FC = ({ ...allRest, }) } - }, [values, selectedNode]) + }, [values, selectedNode, form]) const handleAddVariable = () => { variableModalRef.current?.handleOpen() @@ -192,16 +313,95 @@ const Properties: FC = ({ .map(node => node.id); }; + // Find parent loop/iteration node if current node is a child + const getParentLoopNode = (nodeId: string): Node | null => { + const node = nodes.find(n => n.id === nodeId); + if (!node) return null; + + const nodeData = node.getData(); + const cycle = nodeData?.cycle; + + if (cycle) { + const parentNode = nodes.find(n => n.getData().id === cycle); + if (parentNode) { + const parentData = parentNode.getData(); + if (parentData?.type === 'loop' || parentData?.type === 'iteration') { + return parentNode; + } + } + } + return null; + }; + const allPreviousNodeIds = getAllPreviousNodes(selectedNode.id); const childNodeIds = getChildNodes(selectedNode.id); - console.log('childNodeIds', childNodeIds) + const parentLoopNode = getParentLoopNode(selectedNode.id); + + console.log('childNodeIds', selectedNode, childNodeIds) const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds]; + // Add parent loop/iteration node variables if current node is a child + if (parentLoopNode) { + const parentData = parentLoopNode.getData(); + const parentNodeId = parentLoopNode.getData().id; + + if (parentData.type === 'loop') { + const cycleVars = parentData.cycle_vars || []; + cycleVars.forEach((cycleVar: any) => { + const key = `${parentNodeId}_cycle_${cycleVar.name}`; + if (!addedKeys.has(key)) { + addedKeys.add(key); + variableList.push({ + key, + label: cycleVar.name, + type: 'variable', + dataType: cycleVar.type || 'String', + value: 
`${parentNodeId}.${cycleVar.name}`, + nodeData: parentData, + }); + } + }); + } else if (parentData.type === 'iteration') { + // Add item and index variables for iteration parent + const itemKey = `${parentNodeId}_item`; + const indexKey = `${parentNodeId}_index`; + + if (!addedKeys.has(itemKey)) { + addedKeys.add(itemKey); + variableList.push({ + key: itemKey, + label: 'item', + type: 'variable', + dataType: 'Object', + value: `${parentNodeId}.item`, + nodeData: parentData, + }); + } + + if (!addedKeys.has(indexKey)) { + addedKeys.add(indexKey); + variableList.push({ + key: indexKey, + label: 'index', + type: 'variable', + dataType: 'Number', + value: `${parentNodeId}.index`, + nodeData: parentData, + }); + } + } + + // Add variables from nodes preceding the parent loop/iteration node + const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id); + allRelevantNodeIds.push(...parentPreviousNodeIds); + } + allRelevantNodeIds.forEach(nodeId => { const node = nodes.find(n => n.id === nodeId); if (!node) return; const nodeData = node.getData(); + const dataNodeId = nodeData.id; // Use the data.id instead of node.id for consistency switch(nodeData.type) { case 'start': @@ -211,7 +411,7 @@ const Properties: FC = ({ ] list.forEach((variable: any) => { if (!variable || !variable?.name) return; - const key = `${nodeId}_${variable.name}`; + const key = `${dataNodeId}_${variable.name}`; if (!addedKeys.has(key)) { addedKeys.add(key); variableList.push({ @@ -219,14 +419,14 @@ const Properties: FC = ({ label: variable.name, type: 'variable', dataType: variable.type, - value: `{{${nodeId}.${variable.name}}}`, + value: `${dataNodeId}.${variable.name}`, nodeData: nodeData, }); } }); nodeData.config?.variables?.sys?.forEach((variable: any) => { if (!variable || !variable?.name) return; - const key = `${nodeId}_sys_${variable.name}`; + const key = `${dataNodeId}_sys_${variable.name}`; if (!addedKeys.has(key)) { addedKeys.add(key); variableList.push({ @@ -241,7 +441,7 @@ 
const Properties: FC = ({ }); break case 'llm': - const llmKey = `${nodeId}_output`; + const llmKey = `${dataNodeId}_output`; if (!addedKeys.has(llmKey)) { addedKeys.add(llmKey); variableList.push({ @@ -249,13 +449,13 @@ const Properties: FC = ({ label: 'output', type: 'variable', dataType: 'String', - value: `${nodeId}.output`, + value: `${dataNodeId}.output`, nodeData: nodeData, }); } break case 'knowledge-retrieval': - const knowledgeKey = `${nodeId}_message`; + const knowledgeKey = `${dataNodeId}_message`; if (!addedKeys.has(knowledgeKey)) { addedKeys.add(knowledgeKey); variableList.push({ @@ -263,7 +463,219 @@ const Properties: FC = ({ label: 'message', type: 'variable', dataType: 'array[object]', - value: `${nodeId}.message`, + value: `${dataNodeId}.message`, + nodeData: nodeData, + }); + } + break + case 'parameter-extractor': + const successKey = `${dataNodeId}___is_success`; + const reasonKey = `${dataNodeId}___reason`; + if (!addedKeys.has(successKey)) { + addedKeys.add(successKey); + variableList.push({ + key: successKey, + label: '__is_success', + type: 'variable', + dataType: 'number', + value: `${dataNodeId}.__is_success`, + nodeData: nodeData, + }); + } + if (!addedKeys.has(reasonKey)) { + addedKeys.add(reasonKey); + variableList.push({ + key: reasonKey, + label: '__reason', + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.__reason`, + nodeData: nodeData, + }); + } + // Add params variables + const paramsList = nodeData.config?.params?.defaultValue || []; + paramsList.forEach((param: any) => { + if (!param || !param?.name) return; + const paramKey = `${dataNodeId}_${param.name}`; + if (!addedKeys.has(paramKey)) { + addedKeys.add(paramKey); + variableList.push({ + key: paramKey, + label: param.name, + type: 'variable', + dataType: param.type || 'string', + value: `${dataNodeId}.${param.name}`, + nodeData: nodeData, + }); + } + }); + break + case 'var-aggregator': + if (nodeData.config.group.defaultValue) { + // If group=true, add 
variables from group_variables with key as variable name + const groupVariables = nodeData.config.group_variables.defaultValue || []; + groupVariables?.forEach((groupVar: any) => { + if (!groupVar || !groupVar.key) return; + const groupVarKey = `${dataNodeId}_${groupVar.key}`; + if (!addedKeys.has(groupVarKey)) { + addedKeys.add(groupVarKey); + variableList.push({ + key: groupVarKey, + label: groupVar.key, + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.${groupVar.key}`, + nodeData: nodeData, + }); + } + }); + } else { + // If group=false, add output variable + const varAggregatorKey = `${dataNodeId}_output`; + if (!addedKeys.has(varAggregatorKey)) { + addedKeys.add(varAggregatorKey); + variableList.push({ + key: varAggregatorKey, + label: 'output', + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.output`, + nodeData: nodeData, + }); + } + } + break + case 'http-request': + const httpBodyKey = `${dataNodeId}_body`; + const httpStatusKey = `${dataNodeId}_status_code`; + if (!addedKeys.has(httpBodyKey)) { + addedKeys.add(httpBodyKey); + variableList.push({ + key: httpBodyKey, + label: 'body', + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.body`, + nodeData: nodeData, + }); + } + if (!addedKeys.has(httpStatusKey)) { + addedKeys.add(httpStatusKey); + variableList.push({ + key: httpStatusKey, + label: 'status_code', + type: 'variable', + dataType: 'number', + value: `${dataNodeId}.status_code`, + nodeData: nodeData, + }); + } + break + case 'jinja-render': + const jinjaOutputKey = `${dataNodeId}_output`; + if (!addedKeys.has(jinjaOutputKey)) { + addedKeys.add(jinjaOutputKey); + variableList.push({ + key: jinjaOutputKey, + label: 'output', + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.output`, + nodeData: nodeData, + }); + } + break + case 'question-classifier': + const classNameKey = `${dataNodeId}_class_name`; + const outputKey = `${dataNodeId}_output`; + if (!addedKeys.has(classNameKey)) { + 
addedKeys.add(classNameKey); + variableList.push({ + key: classNameKey, + label: 'class_name', + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.class_name`, + nodeData: nodeData, + }); + } + if (!addedKeys.has(outputKey)) { + addedKeys.add(outputKey); + variableList.push({ + key: outputKey, + label: 'output', + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.output`, + nodeData: nodeData, + }); + } + break + case 'iteration': + console.log('iteration addedKeys', addedKeys) + const iterationOutputKey = `${dataNodeId}_output`; + const iterationItemKey = `${dataNodeId}_item`; + if (!addedKeys.has(iterationOutputKey)) { + addedKeys.add(iterationOutputKey); + // Get the data type from the output configuration, default to string + const outputConfig = nodeData.output; + let outputDataType = 'string'; + if (outputConfig) { + // Find the selected variable from variableList to get its type + const selectedVariable = variableList.find(v => v.value === outputConfig); + if (selectedVariable) { + outputDataType = selectedVariable.dataType; + } + } + variableList.push({ + key: iterationOutputKey, + label: 'output', + type: 'variable', + dataType: outputDataType, + value: `${dataNodeId}.output`, + nodeData: nodeData, + }); + } + if (!addedKeys.has(iterationItemKey)) { + addedKeys.add(iterationItemKey); + variableList.push({ + key: iterationItemKey, + label: 'item', + type: 'variable', + dataType: 'string', + value: `${dataNodeId}.item`, + nodeData: nodeData, + }); + } + break + case 'loop': + const cycleVars = nodeData.config.cycle_vars.defaultValue || []; + cycleVars.forEach((cycleVar: any) => { + const cycleVarKey = `${dataNodeId}_cycle_${cycleVar.name}`; + if (!addedKeys.has(cycleVarKey)) { + addedKeys.add(cycleVarKey); + variableList.push({ + key: cycleVarKey, + label: cycleVar.name, + type: 'variable', + dataType: cycleVar.type || 'string', + value: `${dataNodeId}.${cycleVar.name}`, + nodeData: nodeData, + }); + } + }); + break + case 'tool': 
+ const toolDataKey = `${dataNodeId}_data`; + if (!addedKeys.has(toolDataKey)) { + addedKeys.add(toolDataKey); + variableList.push({ + key: toolDataKey, + label: 'data', + type: 'variable', + dataType: 'object', + value: `${dataNodeId}.data`, nodeData: nodeData, }); } @@ -272,7 +684,7 @@ const Properties: FC = ({ }); // Add conversation variables from global config - const conversationVariables = config?.variables || []; + const conversationVariables = workflowConfig?.variables || []; conversationVariables.forEach((variable: any) => { const key = `CONVERSATION_${variable.name}`; @@ -283,7 +695,7 @@ const Properties: FC = ({ label: variable.name, type: 'variable', dataType: variable.type, - value: `conversation.${variable.name}`, + value: `conv.${variable.name}`, nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' }, group: 'CONVERSATION' }); @@ -291,7 +703,15 @@ const Properties: FC = ({ }); return variableList; - }, [selectedNode, graphRef]); + }, [selectedNode, graphRef, workflowConfig?.variables]); + + // Filter out boolean type variables for loop and llm nodes + const getFilteredVariableList = (nodeType?: string) => { + if (nodeType === 'loop' || nodeType === 'llm') { + return variableList.filter(variable => variable.dataType !== 'boolean'); + } + return variableList; + }; console.log('values', values) console.log('variableList', variableList, selectedNode?.data) @@ -317,6 +737,8 @@ const Properties: FC = ({ {selectedNode?.data?.type === 'http-request' ? : selectedNode?.data?.type === 'tool' ? 
@@ -374,7 +796,7 @@ const Properties: FC = ({ if (selectedNode?.data?.type === 'llm' && key === 'messages' && config.type === 'define') { // 为llm节点且isArray=true时添加context变量支持 - let contextVariableList = [...variableList]; + let contextVariableList = [...getFilteredVariableList('llm')]; const isArrayMode = config.isArray !== false; // 默认为true if (isArrayMode) { @@ -387,7 +809,7 @@ const Properties: FC = ({ label: 'context', type: 'variable', dataType: 'String', - value: `{{context}}`, + value: `context`, nodeData: selectedNode.getData(), isContext: true, }); @@ -396,14 +818,14 @@ const Properties: FC = ({ return ( - + ) } if (selectedNode?.data?.type === 'end' && key === 'output') { return ( - + ) } @@ -430,7 +852,8 @@ const Properties: FC = ({ title={t(`workflow.config.${selectedNode?.data?.type}.${key}`)} isArray={!!config.isArray} parentName={key} - options={variableList} + enableJinja2={config.enableJinja2 as boolean} + options={getFilteredVariableList(selectedNode?.data?.type)} /> ) @@ -451,7 +874,7 @@ const Properties: FC = ({ @@ -463,7 +886,7 @@ const Properties: FC = ({ @@ -476,7 +899,7 @@ const Properties: FC = ({ - + ) @@ -486,7 +909,7 @@ const Properties: FC = ({ ) @@ -498,71 +921,11 @@ const Properties: FC = ({ parentName={key} options={(() => { if (config.filterLoopIterationVars) { - // Add loop cycle variables and iteration item/index variables const loopIterationVars: Suggestion[] = []; - const graph = graphRef.current; - if (graph && selectedNode) { - const nodes = graph.getNodes(); - - // Find parent loop/iteration nodes - const findParentLoopIteration = (nodeId: string): string[] => { - const node = nodes.find(n => n.id === nodeId); - if (!node) return []; - - const nodeData = node.getData(); - const cycle = nodeData?.cycle; - - if (cycle) { - const parentNode = nodes.find(n => n.getData().id === cycle); - if (parentNode) { - const parentData = parentNode.getData(); - if (parentData?.type === 'loop') { - console.log('parentData', parentData) - // 
Add cycle variables from loop node - const cycleVars = parentData.cycle_vars || []; - cycleVars.forEach((cycleVar: any) => { - loopIterationVars.push({ - key: `${cycle}_cycle_${cycleVar.name}`, - label: cycleVar.name, - type: 'variable', - dataType: 'String', - value: `${cycle}.${cycleVar.name}`, - nodeData: parentData, - }); - }); - } else if (parentData?.type === 'iteration') { - // Add item and index variables from iteration node - loopIterationVars.push( - { - key: `${cycle}_item`, - label: 'item', - type: 'variable', - dataType: 'Object', - value: `${cycle}.item`, - nodeData: parentData, - }, - { - key: `${cycle}_index`, - label: 'index', - type: 'variable', - dataType: 'Number', - value: `${cycle}.index`, - nodeData: parentData, - } - ); - } - return [cycle, ...findParentLoopIteration(cycle)]; - } - } - return []; - }; - - findParentLoopIteration(selectedNode.id); - } - return [...variableList, ...loopIterationVars]; + return [...getFilteredVariableList(selectedNode?.data?.type), ...loopIterationVars]; } - return variableList; + return getFilteredVariableList(selectedNode?.data?.type); })() } /> @@ -583,11 +946,15 @@ const Properties: FC = ({ ? : config.type === 'select' ?